udion committed on
Commit
99e984c
1 Parent(s): 5b3ab4e

hfspace gradio demo

LICENSE ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
README.md CHANGED
@@ -1,13 +1,2 @@
- ---
- title: BayesCap
- emoji: 🦀
- colorFrom: purple
- colorTo: green
- sdk: gradio
- sdk_version: 3.0.24
- app_file: app.py
- pinned: false
- license: mit
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # BayesCap
+ Bayesian Identity Cap for Calibrated Uncertainty in Pretrained Neural Networks
app.py ADDED
@@ -0,0 +1,119 @@
import numpy as np
import random
import matplotlib.pyplot as plt
from matplotlib import cm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

from PIL import Image

from ds import *
from losses import *
from networks_SRGAN import *
from utils import *

device = 'cuda'


# Build the frozen SRGAN generator and the BayesCap uncertainty head.
NetG = Generator()
model_parameters = filter(lambda p: True, NetG.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
print("Number of Parameters:", params)
NetC = BayesCap(in_channels=3, out_channels=3)

ensure_checkpoint_exists('BayesCap_SRGAN.pth')
NetG.load_state_dict(torch.load('BayesCap_SRGAN.pth', map_location=device))
NetG.to(device)
NetG.eval()

ensure_checkpoint_exists('BayesCap_ckpt.pth')
NetC.load_state_dict(torch.load('BayesCap_ckpt.pth', map_location=device))
NetC.to(device)
NetC.eval()

def tensor01_to_pil(xt):
    r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
    return r


def predict(img):
    """
    img: input PIL image.
    Returns the low-res input, the super-resolved mean, and the alpha, beta and uncertainty maps.
    """
    image_size = (256, 256)
    upscale_factor = 4
    lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
    # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)

    img = Image.fromarray(np.array(img))
    img = lr_transforms(img)
    lr_tensor = utils.image2tensor(img, range_norm=False, half=False)

    device = 'cuda'
    dtype = torch.cuda.FloatTensor
    xLR = lr_tensor.to(device).unsqueeze(0)
    xLR = xLR.type(dtype)
    # pass them through the network
    with torch.no_grad():
        xSR = NetG(xLR)
        xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)

    # per-pixel variance of the generalized Gaussian: alpha^2 * Gamma(3/beta) / Gamma(1/beta)
    a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
    b_map = xSRC_beta[0].to('cpu').data
    u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))


    x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))

    x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))

    #im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))

    a_map = torch.clamp(a_map, min=0, max=0.1)
    a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
    x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))

    b_map = torch.clamp(b_map, min=0.45, max=0.75)
    b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
    x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))

    u_map = torch.clamp(u_map, min=0, max=0.15)
    u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
    x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))

    return x_LR, x_mean, x_alpha, x_beta, x_uncer

import gradio as gr

title = "BayesCap"
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks | <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"


gr.Interface(
    fn=predict,
    inputs=gr.inputs.Image(type='pil', label="Original"),
    outputs=[
        gr.outputs.Image(type='pil', label="Low-res"),
        gr.outputs.Image(type='pil', label="Super-res"),
        gr.outputs.Image(type='pil', label="Alpha"),
        gr.outputs.Image(type='pil', label="Beta"),
        gr.outputs.Image(type='pil', label="Uncertainty")
    ],
    title=title,
    description=description,
    article=article,
    examples=[
        ["./demo_examples/baby.png"],
        ["./demo_examples/bird.png"],
        ["./demo_examples/butterfly.png"],
        ["./demo_examples/head.png"],
        ["./demo_examples/woman.png"],
    ]
).launch(share=True)
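
Note: the `u_map` computed in `predict()` is the per-pixel variance of a generalized Gaussian with scale `a_map` (the reciprocal of the predicted alpha) and shape `b_map`. A minimal, self-contained sketch of the same identity, Var = alpha^2 * Gamma(3/beta) / Gamma(1/beta), is shown below; the function name and the toy tensors are illustrative only and are not part of the commit.

import torch

def ggd_variance(alpha: torch.Tensor, beta: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
    """Variance of a generalized Gaussian with scale alpha and shape beta.

    Computed via lgamma for numerical stability, mirroring the u_map line in predict().
    """
    b = beta + eps  # same small offset as the demo code
    return alpha**2 * torch.exp(torch.lgamma(3.0 / b) - torch.lgamma(1.0 / b))

# toy check: with the same scale, a larger shape parameter gives a much smaller variance
alpha = torch.tensor([0.05, 0.05])
beta = torch.tensor([0.5, 2.0])
print(ggd_variance(alpha, beta))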
demo_examples/baby.png ADDED
demo_examples/bird.png ADDED
demo_examples/butterfly.png ADDED
demo_examples/head.png ADDED
demo_examples/woman.png ADDED
ds.py ADDED
@@ -0,0 +1,485 @@
from __future__ import absolute_import, division, print_function

import random
import copy
import io
import os
import numpy as np
from PIL import Image
import skimage.transform
from collections import Counter


import torch
import torch.utils.data as data
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode as IMode

import utils

class ImgDset(Dataset):
    """Customize the data set loading function and prepare low/high resolution image data in advance.

    Args:
        dataroot (str): Training data set address
        image_size (int): High resolution image size
        upscale_factor (int): Image magnification
        mode (str): Data set loading method, the training data set is for data enhancement,
            and the verification data set is not for data enhancement
    """

    def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(ImgDset, self).__init__()
        self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]

        if mode == "train":
            self.hr_transforms = transforms.Compose([
                transforms.RandomCrop(image_size),
                transforms.RandomRotation(90),
                transforms.RandomHorizontalFlip(0.5),
            ])
        else:
            self.hr_transforms = transforms.Resize(image_size)

        self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)

    def __getitem__(self, batch_index: int) -> [Tensor, Tensor]:
        # Read a batch of image data
        image = Image.open(self.filenames[batch_index])

        # Transform image
        hr_image = self.hr_transforms(image)
        lr_image = self.lr_transforms(hr_image)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
        hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)

        return lr_tensor, hr_tensor

    def __len__(self) -> int:
        return len(self.filenames)


class PairedImages_w_nameList(Dataset):
    '''
    can act as supervised or un-supervised based on flists
    '''
    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        impath1 = self.flist1[index]
        img1 = Image.open(impath1).convert('RGB')
        impath2 = self.flist2[index]
        img2 = Image.open(impath2).convert('RGB')

        img1 = utils.image2tensor(img1, range_norm=False, half=False)
        img2 = utils.image2tensor(img2, range_norm=False, half=False)

        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)

        return img1, img2

    def __len__(self):
        return len(self.flist1)

class PairedImages_w_nameList_npy(Dataset):
    '''
    can act as supervised or un-supervised based on flists
    '''
    def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
        self.flist1 = flist1
        self.flist2 = flist2
        self.transform1 = transform1
        self.transform2 = transform2
        self.do_aug = do_aug

    def __getitem__(self, index):
        impath1 = self.flist1[index]
        img1 = np.load(impath1)
        impath2 = self.flist2[index]
        img2 = np.load(impath2)

        if self.transform1 is not None:
            img1 = self.transform1(img1)
        if self.transform2 is not None:
            img2 = self.transform2(img2)

        return img1, img2

    def __len__(self):
        return len(self.flist1)

# def call_paired():
#     root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
#     root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'

#     flist1=glob.glob(root1+'/*/*.png')
#     flist2=glob.glob(root2+'/*/*.png')

#     dset = PairedImages_w_nameList(root1,root2,flist1,flist2)

#### KITTI depth

def load_velodyne_points(filename):
    """Load 3D point cloud from KITTI file format
    (adapted from https://github.com/hunse/kitti)
    """
    points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
    points[:, 3] = 1.0  # homogeneous
    return points


def read_calib_file(path):
    """Read KITTI calibration file
    (from https://github.com/hunse/kitti)
    """
    float_chars = set("0123456789.e+- ")
    data = {}
    with open(path, 'r') as f:
        for line in f.readlines():
            key, value = line.split(':', 1)
            value = value.strip()
            data[key] = value
            if float_chars.issuperset(value):
                # try to cast to float array
                try:
                    data[key] = np.array(list(map(float, value.split(' '))))
                except ValueError:
                    # casting error: data[key] already eq. value, so pass
                    pass

    return data


def sub2ind(matrixSize, rowSub, colSub):
    """Convert row, col matrix subscripts to linear indices
    """
    m, n = matrixSize
    return rowSub * (n-1) + colSub - 1


def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
    """Generate a depth map from velodyne data
    """
    # load calibration files
    cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
    velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
    velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
    velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))

    # get image shape
    im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)

    # compute projection matrix velodyne->image plane
    R_cam2rect = np.eye(4)
    R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
    P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
    P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)

    # load velodyne points and remove all behind image plane (approximation)
    # each row of the velodyne data is forward, left, up, reflectance
    velo = load_velodyne_points(velo_filename)
    velo = velo[velo[:, 0] >= 0, :]

    # project the points to the camera
    velo_pts_im = np.dot(P_velo2im, velo.T).T
    velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]

    if vel_depth:
        velo_pts_im[:, 2] = velo[:, 0]

    # check if in bounds
    # use minus 1 to get the exact same value as KITTI matlab code
    velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
    velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
    val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
    val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
    velo_pts_im = velo_pts_im[val_inds, :]

    # project to image (np.int is deprecated in recent numpy; the builtin int is equivalent here)
    depth = np.zeros((im_shape[:2]))
    depth[velo_pts_im[:, 1].astype(int), velo_pts_im[:, 0].astype(int)] = velo_pts_im[:, 2]

    # find the duplicate points and choose the closest depth
    inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
    dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
    for dd in dupe_inds:
        pts = np.where(inds == dd)[0]
        x_loc = int(velo_pts_im[pts[0], 0])
        y_loc = int(velo_pts_im[pts[0], 1])
        depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
    depth[depth < 0] = 0

    return depth

def pil_loader(path):
    # open path as file to avoid ResourceWarning
    # (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        with Image.open(f) as img:
            return img.convert('RGB')


class MonoDataset(data.Dataset):
    """Superclass for monocular dataloaders

    Args:
        data_path
        filenames
        height
        width
        frame_idxs
        num_scales
        is_train
        img_ext
    """
    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        self.interp = Image.ANTIALIAS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()

    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        for k in list(inputs):
            frame = inputs[k]
            if "color" in k:
                n, im, i = k
                for i in range(self.num_scales):
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>) for raw colour images,
            ("color_aug", <frame_id>, <scale>) for augmented colour images,
            ("K", scale) or ("inv_K", scale) for camera intrinsics,
            "stereo_T" for camera extrinsics, and
            "depth_gt" for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1  images at native resolution as loaded from disk
             0  images resized to (self.width,      self.height     )
             1  images resized to (self.width // 2, self.height // 2)
             2  images resized to (self.width // 4, self.height // 4)
             3  images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0

        if len(line) == 3:
            side = line[2]
        else:
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
        # To normalize you need to scale the first row by 1 / image_width and the second row
        # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
        # If your principal point is far from the center you might need to disable the horizontal
        # flip augmentation.
        self.K = np.array([[0.58, 0, 0.5, 0],
                           [0, 1.92, 0.5, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

        self.full_res_shape = (1242, 375)
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}

    def check_depth(self):
        line = self.filenames[0].split()
        scene_name = line[0]
        frame_index = int(line[1])

        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)))

        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        color = self.loader(self.get_image_path(folder, frame_index, side))

        if do_flip:
            color = color.transpose(Image.FLIP_LEFT_RIGHT)

        return color


class KITTIDepthDataset(KITTIDataset):
    """KITTI dataset which uses the updated ground truth depth maps
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDepthDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        f_str = "{:010d}{}".format(frame_index, self.img_ext)
        image_path = os.path.join(
            self.data_path,
            folder,
            "image_0{}/data".format(self.side_map[side]),
            f_str)
        return image_path

    def get_depth(self, folder, frame_index, side, do_flip):
        f_str = "{:010d}.png".format(frame_index)
        depth_path = os.path.join(
            self.data_path,
            folder,
            "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
            f_str)

        depth_gt = Image.open(depth_path)
        depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
        depth_gt = np.array(depth_gt).astype(np.float32) / 256

        if do_flip:
            depth_gt = np.fliplr(depth_gt)

        return depth_gt
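
For orientation, a minimal usage sketch for ImgDset follows; the directory path is hypothetical and it assumes image_size is passed as an (H, W) tuple (which is what the low-res resize above indexes into) and that utils.image2tensor is available as imported at the top of ds.py.

from torch.utils.data import DataLoader
from ds import ImgDset

# hypothetical folder of high-resolution training images
dset = ImgDset(dataroot='data/train_HR', image_size=(256, 256), upscale_factor=4, mode='train')
loader = DataLoader(dset, batch_size=16, shuffle=True, num_workers=2)

lr, hr = next(iter(loader))
print(lr.shape, hr.shape)  # expected: [16, 3, 64, 64] and [16, 3, 256, 256]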
losses.py ADDED
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor

class ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.
    Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.

    Paper reference list:
        -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
        -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
        -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
    """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        # Load the VGG19 model trained on the ImageNet dataset.
        vgg19 = models.vgg19(pretrained=True).eval()
        # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # Freeze model parameters.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False

        # The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        # Standardized operations
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)

        # Find the feature map difference between the two images
        loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))

        return loss


class GenGaussLoss(nn.Module):
    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
    ):
        one_over_alpha1 = one_over_alpha + self.alpha_eps
        beta1 = beta + self.beta_eps

        resi = torch.abs(mean - target)
        # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
        resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
        ## check if resi has nans
        if torch.sum(resi != resi) > 0:
            print('resi has nans!!')
            return None

        log_one_over_alpha = torch.log(one_over_alpha1)
        log_beta = torch.log(beta1)
        lgamma_beta = torch.lgamma(torch.pow(beta1, -1))

        if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
            print('log_one_over_alpha has nan')
        if torch.sum(lgamma_beta != lgamma_beta) > 0:
            print('lgamma_beta has nan')
        if torch.sum(log_beta != log_beta) > 0:
            print('log_beta has nan')

        l = resi - log_one_over_alpha + lgamma_beta - log_beta

        if self.reduction == 'mean':
            return l.mean()
        elif self.reduction == 'sum':
            return l.sum()
        else:
            print('Reduction not supported')
            return None

class TempCombLoss(nn.Module):
    def __init__(
        self, reduction='mean',
        alpha_eps=1e-4, beta_eps=1e-4,
        resi_min=1e-4, resi_max=1e3
    ) -> None:
        super(TempCombLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

        self.L_GenGauss = GenGaussLoss(
            reduction=self.reduction,
            alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
            resi_min=self.resi_min, resi_max=self.resi_max
        )
        self.L_l1 = nn.L1Loss(reduction=self.reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float
    ):
        l1 = self.L_l1(mean, target)
        l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
        l = T1*l1 + T2*l2

        return l


# x1 = torch.randn(4,3,32,32)
# x2 = torch.rand(4,3,32,32)
# x3 = torch.rand(4,3,32,32)
# x4 = torch.randn(4,3,32,32)

# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
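
To make the loss interface concrete, here is a small, hedged usage sketch with random tensors only (shapes follow the commented-out example above): GenGaussLoss takes the prediction, the predicted one-over-alpha and beta maps, and the target, while TempCombLoss blends an L1 term and the generalized-Gaussian term with weights T1 and T2.

import torch
from losses import GenGaussLoss, TempCombLoss

pred = torch.randn(4, 3, 32, 32)            # network output (mu)
one_over_alpha = torch.rand(4, 3, 32, 32)   # predicted 1/alpha map, positive
beta = torch.rand(4, 3, 32, 32)             # predicted shape map, positive
target = torch.randn(4, 3, 32, 32)

nll = GenGaussLoss()(pred, one_over_alpha, beta, target)
combined = TempCombLoss()(pred, one_over_alpha, beta, target, T1=1e0, T2=1e-2)
print(nll.item(), combined.item())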
networks_SRGAN.py ADDED
@@ -0,0 +1,347 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor

# __all__ = [
#     "ResidualConvBlock",
#     "Discriminator", "Generator",
# ]


class ResidualConvBlock(nn.Module):
    """Implements residual conv function.

    Args:
        channels (int): Number of channels in the input image.
    """

    def __init__(self, channels: int) -> None:
        super(ResidualConvBlock, self).__init__()
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.rcb(x)
        out = torch.add(out, identity)

        return out


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # input size. (3) x 96 x 96
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.LeakyReLU(0.2, True),
            # state size. (64) x 48 x 48
            nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # state size. (128) x 24 x 24
            nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # state size. (256) x 12 x 12
            nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 6 x 6
            nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class Generator(nn.Module):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64),
        )

        # Upscale conv block.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        # Output layer.
        self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor, dop=None) -> Tensor:
        if not dop:
            return self._forward_impl(x)
        else:
            return self._forward_w_dop_impl(x, dop)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = F.dropout2d(self.conv_block2(out), p=dop)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


#### BayesCap
class BayesCap(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        self.conv_block3_mu = nn.Conv2d(
            64, out_channels=out_channels,
            kernel_size=9, stride=1, padding=4
        )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_mu, out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


class BayesCap_noID(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap_noID, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        # self.conv_block3_mu = nn.Conv2d(
        #     64, out_channels=out_channels,
        #     kernel_size=9, stride=1, padding=4
        # )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        # out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
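
The BayesCap module above sits on top of the frozen Generator, as app.py does at inference time. Below is a minimal, hedged training sketch under that assumption: the data path, optimizer settings, and loss weights are hypothetical, and training the cap to reconstruct the frozen generator's own output (rather than the ground-truth image) is an assumption following the identity-cap idea; this commit does not include the training script.

import torch
from torch.utils.data import DataLoader

from ds import ImgDset
from losses import TempCombLoss
from networks_SRGAN import Generator, BayesCap

device = 'cuda' if torch.cuda.is_available() else 'cpu'

NetG = Generator().to(device).eval()                 # frozen SRGAN generator
for p in NetG.parameters():
    p.requires_grad = False

NetC = BayesCap(in_channels=3, out_channels=3).to(device).train()
opt = torch.optim.Adam(NetC.parameters(), lr=1e-4)   # hypothetical hyper-parameters
crit = TempCombLoss()

# hypothetical training folder of HR images
loader = DataLoader(ImgDset('data/train_HR', (256, 256), 4, 'train'), batch_size=8)

for xLR, _ in loader:
    xLR = xLR.to(device)
    with torch.no_grad():
        xSR = NetG(xLR)                               # frozen prediction
    mu, one_over_alpha, beta = NetC(xSR)              # identity reconstruction + GGD parameter maps
    loss = crit(mu, one_over_alpha, beta, xSR, T1=1e0, T2=1e-2)  # assumed target: the frozen output
    opt.zero_grad()
    loss.backward()
    opt.step()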
networks_T1toT2.py ADDED
@@ -0,0 +1,477 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+
6
+ ### components
7
+ class ResConv(nn.Module):
8
+ """
9
+ Residual convolutional block, where
10
+ convolutional block consists: (convolution => [BN] => ReLU) * 3
11
+ residual connection adds the input to the output
12
+ """
13
+ def __init__(self, in_channels, out_channels, mid_channels=None):
14
+ super().__init__()
15
+ if not mid_channels:
16
+ mid_channels = out_channels
17
+ self.double_conv = nn.Sequential(
18
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
19
+ nn.BatchNorm2d(mid_channels),
20
+ nn.ReLU(inplace=True),
21
+ nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
22
+ nn.BatchNorm2d(mid_channels),
23
+ nn.ReLU(inplace=True),
24
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
25
+ nn.BatchNorm2d(out_channels),
26
+ nn.ReLU(inplace=True)
27
+ )
28
+ self.double_conv1 = nn.Sequential(
29
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
30
+ nn.BatchNorm2d(out_channels),
31
+ nn.ReLU(inplace=True),
32
+ )
33
+ def forward(self, x):
34
+ x_in = self.double_conv1(x)
35
+ x1 = self.double_conv(x)
36
+ return self.double_conv(x) + x_in
37
+
38
+ class Down(nn.Module):
39
+ """Downscaling with maxpool then Resconv"""
40
+ def __init__(self, in_channels, out_channels):
41
+ super().__init__()
42
+ self.maxpool_conv = nn.Sequential(
43
+ nn.MaxPool2d(2),
44
+ ResConv(in_channels, out_channels)
45
+ )
46
+ def forward(self, x):
47
+ return self.maxpool_conv(x)
48
+
49
+ class Up(nn.Module):
50
+ """Upscaling then double conv"""
51
+ def __init__(self, in_channels, out_channels, bilinear=True):
52
+ super().__init__()
53
+ # if bilinear, use the normal convolutions to reduce the number of channels
54
+ if bilinear:
55
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
56
+ self.conv = ResConv(in_channels, out_channels, in_channels // 2)
57
+ else:
58
+ self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
59
+ self.conv = ResConv(in_channels, out_channels)
60
+ def forward(self, x1, x2):
61
+ x1 = self.up(x1)
62
+ # input is CHW
63
+ diffY = x2.size()[2] - x1.size()[2]
64
+ diffX = x2.size()[3] - x1.size()[3]
65
+ x1 = F.pad(
66
+ x1,
67
+ [
68
+ diffX // 2, diffX - diffX // 2,
69
+ diffY // 2, diffY - diffY // 2
70
+ ]
71
+ )
72
+ # if you have padding issues, see
73
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
74
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
75
+ x = torch.cat([x2, x1], dim=1)
76
+ return self.conv(x)
77
+
78
+ class OutConv(nn.Module):
79
+ def __init__(self, in_channels, out_channels):
80
+ super(OutConv, self).__init__()
81
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
82
+ def forward(self, x):
83
+ # return F.relu(self.conv(x))
84
+ return self.conv(x)
85
+
86
+ ##### The composite networks
87
+ class UNet(nn.Module):
88
+ def __init__(self, n_channels, out_channels, bilinear=True):
89
+ super(UNet, self).__init__()
90
+ self.n_channels = n_channels
91
+ self.out_channels = out_channels
92
+ self.bilinear = bilinear
93
+ ####
94
+ self.inc = ResConv(n_channels, 64)
95
+ self.down1 = Down(64, 128)
96
+ self.down2 = Down(128, 256)
97
+ self.down3 = Down(256, 512)
98
+ factor = 2 if bilinear else 1
99
+ self.down4 = Down(512, 1024 // factor)
100
+ self.up1 = Up(1024, 512 // factor, bilinear)
101
+ self.up2 = Up(512, 256 // factor, bilinear)
102
+ self.up3 = Up(256, 128 // factor, bilinear)
103
+ self.up4 = Up(128, 64, bilinear)
104
+ self.outc = OutConv(64, out_channels)
105
+ def forward(self, x):
106
+ x1 = self.inc(x)
107
+ x2 = self.down1(x1)
108
+ x3 = self.down2(x2)
109
+ x4 = self.down3(x3)
110
+ x5 = self.down4(x4)
111
+ x = self.up1(x5, x4)
112
+ x = self.up2(x, x3)
113
+ x = self.up3(x, x2)
114
+ x = self.up4(x, x1)
115
+ y = self.outc(x)
116
+ return y
117
+
118
+ class CasUNet(nn.Module):
119
+ def __init__(self, n_unet, io_channels, bilinear=True):
120
+ super(CasUNet, self).__init__()
121
+ self.n_unet = n_unet
122
+ self.io_channels = io_channels
123
+ self.bilinear = bilinear
124
+ ####
125
+ self.unet_list = nn.ModuleList()
126
+ for i in range(self.n_unet):
127
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
128
+ def forward(self, x, dop=None):
129
+ y = x
130
+ for i in range(self.n_unet):
131
+ if i==0:
132
+ if dop is not None:
133
+ y = F.dropout2d(self.unet_list[i](y), p=dop)
134
+ else:
135
+ y = self.unet_list[i](y)
136
+ else:
137
+ y = self.unet_list[i](y+x)
138
+ return y
139
+
140
+ class CasUNet_2head(nn.Module):
141
+ def __init__(self, n_unet, io_channels, bilinear=True):
142
+ super(CasUNet_2head, self).__init__()
143
+ self.n_unet = n_unet
144
+ self.io_channels = io_channels
145
+ self.bilinear = bilinear
146
+ ####
147
+ self.unet_list = nn.ModuleList()
148
+ for i in range(self.n_unet):
149
+ if i != self.n_unet-1:
150
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
151
+ else:
152
+ self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
153
+ def forward(self, x):
154
+ y = x
155
+ for i in range(self.n_unet):
156
+ if i==0:
157
+ y = self.unet_list[i](y)
158
+ else:
159
+ y = self.unet_list[i](y+x)
160
+ y_mean, y_sigma = y[0], y[1]
161
+ return y_mean, y_sigma
162
+
163
+ class CasUNet_3head(nn.Module):
164
+ def __init__(self, n_unet, io_channels, bilinear=True):
165
+ super(CasUNet_3head, self).__init__()
166
+ self.n_unet = n_unet
167
+ self.io_channels = io_channels
168
+ self.bilinear = bilinear
169
+ ####
170
+ self.unet_list = nn.ModuleList()
171
+ for i in range(self.n_unet):
172
+ if i != self.n_unet-1:
173
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
174
+ else:
175
+ self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
176
+ def forward(self, x):
177
+ y = x
178
+ for i in range(self.n_unet):
179
+ if i==0:
180
+ y = self.unet_list[i](y)
181
+ else:
182
+ y = self.unet_list[i](y+x)
183
+ y_mean, y_alpha, y_beta = y[0], y[1], y[2]
184
+ return y_mean, y_alpha, y_beta
185
+
186
+ class UNet_2head(nn.Module):
187
+ def __init__(self, n_channels, out_channels, bilinear=True):
188
+ super(UNet_2head, self).__init__()
189
+ self.n_channels = n_channels
190
+ self.out_channels = out_channels
191
+ self.bilinear = bilinear
192
+ ####
193
+ self.inc = ResConv(n_channels, 64)
194
+ self.down1 = Down(64, 128)
195
+ self.down2 = Down(128, 256)
196
+ self.down3 = Down(256, 512)
197
+ factor = 2 if bilinear else 1
198
+ self.down4 = Down(512, 1024 // factor)
199
+ self.up1 = Up(1024, 512 // factor, bilinear)
200
+ self.up2 = Up(512, 256 // factor, bilinear)
201
+ self.up3 = Up(256, 128 // factor, bilinear)
202
+ self.up4 = Up(128, 64, bilinear)
203
+ # the mean head may output multiple channels per pixel
204
+ self.out_mean = OutConv(64, out_channels)
205
+ # the variance head outputs a single value per pixel
206
+ self.out_var = nn.Sequential(
207
+ OutConv(64, 128),
208
+ OutConv(128, 1),
209
+ )
210
+ def forward(self, x):
211
+ x1 = self.inc(x)
212
+ x2 = self.down1(x1)
213
+ x3 = self.down2(x2)
214
+ x4 = self.down3(x3)
215
+ x5 = self.down4(x4)
216
+ x = self.up1(x5, x4)
217
+ x = self.up2(x, x3)
218
+ x = self.up3(x, x2)
219
+ x = self.up4(x, x1)
220
+ y_mean, y_var = self.out_mean(x), self.out_var(x)
221
+ return y_mean, y_var
222
+
223
+ class UNet_3head(nn.Module):
224
+ def __init__(self, n_channels, out_channels, bilinear=True):
225
+ super(UNet_3head, self).__init__()
226
+ self.n_channels = n_channels
227
+ self.out_channels = out_channels
228
+ self.bilinear = bilinear
229
+ ####
230
+ self.inc = ResConv(n_channels, 64)
231
+ self.down1 = Down(64, 128)
232
+ self.down2 = Down(128, 256)
233
+ self.down3 = Down(256, 512)
234
+ factor = 2 if bilinear else 1
235
+ self.down4 = Down(512, 1024 // factor)
236
+ self.up1 = Up(1024, 512 // factor, bilinear)
237
+ self.up2 = Up(512, 256 // factor, bilinear)
238
+ self.up3 = Up(256, 128 // factor, bilinear)
239
+ self.up4 = Up(128, 64, bilinear)
240
+ # the mean head may output multiple channels per pixel
241
+ self.out_mean = OutConv(64, out_channels)
242
+ # the alpha and beta heads each output a single value per pixel
243
+ self.out_alpha = nn.Sequential(
244
+ OutConv(64, 128),
245
+ OutConv(128, 1),
246
+ nn.ReLU()
247
+ )
248
+ self.out_beta = nn.Sequential(
249
+ OutConv(64, 128),
250
+ OutConv(128, 1),
251
+ nn.ReLU()
252
+ )
253
+ def forward(self, x):
254
+ x1 = self.inc(x)
255
+ x2 = self.down1(x1)
256
+ x3 = self.down2(x2)
257
+ x4 = self.down3(x3)
258
+ x5 = self.down4(x4)
259
+ x = self.up1(x5, x4)
260
+ x = self.up2(x, x3)
261
+ x = self.up3(x, x2)
262
+ x = self.up4(x, x1)
263
+ y_mean, y_alpha, y_beta = self.out_mean(x), \
264
+ self.out_alpha(x), self.out_beta(x)
265
+ return y_mean, y_alpha, y_beta
266
+
267
+ class ResidualBlock(nn.Module):
268
+ def __init__(self, in_features):
269
+ super(ResidualBlock, self).__init__()
270
+ conv_block = [
271
+ nn.ReflectionPad2d(1),
272
+ nn.Conv2d(in_features, in_features, 3),
273
+ nn.InstanceNorm2d(in_features),
274
+ nn.ReLU(inplace=True),
275
+ nn.ReflectionPad2d(1),
276
+ nn.Conv2d(in_features, in_features, 3),
277
+ nn.InstanceNorm2d(in_features)
278
+ ]
279
+ self.conv_block = nn.Sequential(*conv_block)
280
+ def forward(self, x):
281
+ return x + self.conv_block(x)
282
+
283
+ class Generator(nn.Module):
284
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9):
285
+ super(Generator, self).__init__()
286
+ # Initial convolution block
287
+ model = [
288
+ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
289
+ nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
290
+ ]
291
+ # Downsampling
292
+ in_features = 64
293
+ out_features = in_features*2
294
+ for _ in range(2):
295
+ model += [
296
+ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
297
+ nn.InstanceNorm2d(out_features),
298
+ nn.ReLU(inplace=True)
299
+ ]
300
+ in_features = out_features
301
+ out_features = in_features*2
302
+ # Residual blocks
303
+ for _ in range(n_residual_blocks):
304
+ model += [ResidualBlock(in_features)]
305
+ # Upsampling
306
+ out_features = in_features//2
307
+ for _ in range(2):
308
+ model += [
309
+ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
310
+ nn.InstanceNorm2d(out_features),
311
+ nn.ReLU(inplace=True)
312
+ ]
313
+ in_features = out_features
314
+ out_features = in_features//2
315
+ # Output layer
316
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
317
+ self.model = nn.Sequential(*model)
318
+ def forward(self, x):
319
+ return self.model(x)
320
+
321
+
322
+ class ResnetGenerator(nn.Module):
323
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
324
+ We adapt Torch code and ideas from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style).
325
+ """
326
+
327
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
328
+ """Construct a Resnet-based generator
329
+ Parameters:
330
+ input_nc (int) -- the number of channels in input images
331
+ output_nc (int) -- the number of channels in output images
332
+ ngf (int) -- the number of filters in the last conv layer
333
+ norm_layer -- normalization layer
334
+ use_dropout (bool) -- if use dropout layers
335
+ n_blocks (int) -- the number of ResNet blocks
336
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
337
+ """
338
+ assert(n_blocks >= 0)
339
+ super(ResnetGenerator, self).__init__()
340
+ if type(norm_layer) == functools.partial:
341
+ use_bias = norm_layer.func == nn.InstanceNorm2d
342
+ else:
343
+ use_bias = norm_layer == nn.InstanceNorm2d
344
+
345
+ model = [nn.ReflectionPad2d(3),
346
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
347
+ norm_layer(ngf),
348
+ nn.ReLU(True)]
349
+
350
+ n_downsampling = 2
351
+ for i in range(n_downsampling): # add downsampling layers
352
+ mult = 2 ** i
353
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
354
+ norm_layer(ngf * mult * 2),
355
+ nn.ReLU(True)]
356
+
357
+ mult = 2 ** n_downsampling
358
+ for i in range(n_blocks): # add ResNet blocks
359
+
360
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
361
+
362
+ for i in range(n_downsampling): # add upsampling layers
363
+ mult = 2 ** (n_downsampling - i)
364
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
365
+ kernel_size=3, stride=2,
366
+ padding=1, output_padding=1,
367
+ bias=use_bias),
368
+ norm_layer(int(ngf * mult / 2)),
369
+ nn.ReLU(True)]
370
+ model += [nn.ReflectionPad2d(3)]
371
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
372
+ model += [nn.Tanh()]
373
+
374
+ self.model = nn.Sequential(*model)
375
+
376
+ def forward(self, input):
377
+ """Standard forward"""
378
+ return self.model(input)
379
+
380
+
381
+ class ResnetBlock(nn.Module):
382
+ """Define a Resnet block"""
383
+
384
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
385
+ """Initialize the Resnet block
386
+ A resnet block is a conv block with skip connections
387
+ We construct a conv block with build_conv_block function,
388
+ and implement skip connections in <forward> function.
389
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
390
+ """
391
+ super(ResnetBlock, self).__init__()
392
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
393
+
394
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
395
+ """Construct a convolutional block.
396
+ Parameters:
397
+ dim (int) -- the number of channels in the conv layer.
398
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
399
+ norm_layer -- normalization layer
400
+ use_dropout (bool) -- if use dropout layers.
401
+ use_bias (bool) -- if the conv layer uses bias or not
402
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
403
+ """
404
+ conv_block = []
405
+ p = 0
406
+ if padding_type == 'reflect':
407
+ conv_block += [nn.ReflectionPad2d(1)]
408
+ elif padding_type == 'replicate':
409
+ conv_block += [nn.ReplicationPad2d(1)]
410
+ elif padding_type == 'zero':
411
+ p = 1
412
+ else:
413
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
414
+
415
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
416
+ if use_dropout:
417
+ conv_block += [nn.Dropout(0.5)]
418
+
419
+ p = 0
420
+ if padding_type == 'reflect':
421
+ conv_block += [nn.ReflectionPad2d(1)]
422
+ elif padding_type == 'replicate':
423
+ conv_block += [nn.ReplicationPad2d(1)]
424
+ elif padding_type == 'zero':
425
+ p = 1
426
+ else:
427
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
428
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
429
+
430
+ return nn.Sequential(*conv_block)
431
+
432
+ def forward(self, x):
433
+ """Forward function (with skip connections)"""
434
+ out = x + self.conv_block(x) # add skip connections
435
+ return out
436
+
437
+ ### discriminator
438
+ class NLayerDiscriminator(nn.Module):
439
+ """Defines a PatchGAN discriminator"""
440
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
441
+ """Construct a PatchGAN discriminator
442
+ Parameters:
443
+ input_nc (int) -- the number of channels in input images
444
+ ndf (int) -- the number of filters in the last conv layer
445
+ n_layers (int) -- the number of conv layers in the discriminator
446
+ norm_layer -- normalization layer
447
+ """
448
+ super(NLayerDiscriminator, self).__init__()
449
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
450
+ use_bias = norm_layer.func == nn.InstanceNorm2d
451
+ else:
452
+ use_bias = norm_layer == nn.InstanceNorm2d
453
+ kw = 4
454
+ padw = 1
455
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
456
+ nf_mult = 1
457
+ nf_mult_prev = 1
458
+ for n in range(1, n_layers): # gradually increase the number of filters
459
+ nf_mult_prev = nf_mult
460
+ nf_mult = min(2 ** n, 8)
461
+ sequence += [
462
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
463
+ norm_layer(ndf * nf_mult),
464
+ nn.LeakyReLU(0.2, True)
465
+ ]
466
+ nf_mult_prev = nf_mult
467
+ nf_mult = min(2 ** n_layers, 8)
468
+ sequence += [
469
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
470
+ norm_layer(ndf * nf_mult),
471
+ nn.LeakyReLU(0.2, True)
472
+ ]
473
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
474
+ self.model = nn.Sequential(*sequence)
475
+ def forward(self, input):
476
+ """Standard forward."""
477
+ return self.model(input)
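+
+ # A minimal smoke test (hypothetical, not part of the committed training code): it builds
+ # the cascaded 3-headed UNet and the PatchGAN discriminator defined above and checks the
+ # output shapes on a random batch.
+ # if __name__ == '__main__':
+ #     net = CasUNet_3head(n_unet=1, io_channels=3)
+ #     mu, alpha, beta = net(torch.randn(1, 3, 64, 64))  # (1,3,64,64), (1,1,64,64), (1,1,64,64)
+ #     disc = NLayerDiscriminator(input_nc=3)
+ #     print(mu.shape, alpha.shape, beta.shape, disc(mu).shape)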
requirements.txt ADDED
@@ -0,0 +1,334 @@
1
+ # This file may be used to create an environment using:
2
+ # $ conda create --name <env> --file <this file>
3
+ # platform: linux-64
4
+ _libgcc_mutex=0.1=conda_forge
5
+ _openmp_mutex=4.5=2_kmp_llvm
6
+ aiohttp=3.8.1=pypi_0
7
+ aiosignal=1.2.0=pypi_0
8
+ albumentations=1.2.0=pyhd8ed1ab_0
9
+ alsa-lib=1.2.6.1=h7f98852_0
10
+ analytics-python=1.4.0=pypi_0
11
+ anyio=3.6.1=pypi_0
12
+ aom=3.3.0=h27087fc_1
13
+ argon2-cffi=21.3.0=pypi_0
14
+ argon2-cffi-bindings=21.2.0=pypi_0
15
+ asttokens=2.0.5=pypi_0
16
+ async-timeout=4.0.2=pypi_0
17
+ attr=2.5.1=h166bdaf_0
18
+ attrs=21.4.0=pypi_0
19
+ babel=2.10.1=pypi_0
20
+ backcall=0.2.0=pypi_0
21
+ backoff=1.10.0=pypi_0
22
+ bcrypt=3.2.2=pypi_0
23
+ beautifulsoup4=4.11.1=pypi_0
24
+ blas=1.0=mkl
25
+ bleach=5.0.0=pypi_0
26
+ blosc=1.21.1=h83bc5f7_3
27
+ brotli=1.0.9=h166bdaf_7
28
+ brotli-bin=1.0.9=h166bdaf_7
29
+ brotlipy=0.7.0=py310h7f8727e_1002
30
+ brunsli=0.1=h9c3ff4c_0
31
+ bzip2=1.0.8=h7b6447c_0
32
+ c-ares=1.18.1=h7f98852_0
33
+ c-blosc2=2.2.0=h7a311fb_0
34
+ ca-certificates=2022.6.15=ha878542_0
35
+ cairo=1.16.0=ha61ee94_1011
36
+ certifi=2022.6.15=py310hff52083_0
37
+ cffi=1.15.0=py310hd667e15_1
38
+ cfitsio=4.1.0=hd9d235c_0
39
+ charls=2.3.4=h9c3ff4c_0
40
+ charset-normalizer=2.0.4=pyhd3eb1b0_0
41
+ click=8.1.3=pypi_0
42
+ cloudpickle=2.1.0=pyhd8ed1ab_0
43
+ cryptography=37.0.1=py310h9ce1e76_0
44
+ cudatoolkit=10.2.89=hfd86e86_1
45
+ cycler=0.11.0=pypi_0
46
+ cytoolz=0.11.2=py310h5764c6d_2
47
+ dask-core=2022.7.0=pyhd8ed1ab_0
48
+ dbus=1.13.6=h5008d03_3
49
+ debugpy=1.6.0=pypi_0
50
+ decorator=5.1.1=pypi_0
51
+ defusedxml=0.7.1=pypi_0
52
+ entrypoints=0.4=pypi_0
53
+ executing=0.8.3=pypi_0
54
+ expat=2.4.8=h27087fc_0
55
+ fastapi=0.78.0=pypi_0
56
+ fastjsonschema=2.15.3=pypi_0
57
+ ffmpeg=4.4.2=habc3f16_0
58
+ ffmpy=0.3.0=pypi_0
59
+ fftw=3.3.10=nompi_h77c792f_102
60
+ fire=0.4.0=pypi_0
61
+ font-ttf-dejavu-sans-mono=2.37=hab24e00_0
62
+ font-ttf-inconsolata=3.000=h77eed37_0
63
+ font-ttf-source-code-pro=2.038=h77eed37_0
64
+ font-ttf-ubuntu=0.83=hab24e00_0
65
+ fontconfig=2.14.0=h8e229c2_0
66
+ fonts-conda-ecosystem=1=0
67
+ fonts-conda-forge=1=0
68
+ fonttools=4.33.3=pypi_0
69
+ freeglut=3.2.2=h9c3ff4c_1
70
+ freetype=2.11.0=h70c0345_0
71
+ frozenlist=1.3.0=pypi_0
72
+ fsspec=2022.5.0=pyhd8ed1ab_0
73
+ ftfy=6.1.1=pypi_0
74
+ gettext=0.19.8.1=h73d1719_1008
75
+ giflib=5.2.1=h7b6447c_0
76
+ glib=2.70.2=h780b84a_4
77
+ glib-tools=2.70.2=h780b84a_4
78
+ gmp=6.2.1=h295c915_3
79
+ gnutls=3.7.6=hb5d6004_1
80
+ gradio=3.0.24=pypi_0
81
+ graphite2=1.3.13=h58526e2_1001
82
+ gst-plugins-base=1.20.3=hf6a322e_0
83
+ gstreamer=1.20.3=hd4edc92_0
84
+ h11=0.12.0=pypi_0
85
+ harfbuzz=4.4.1=hf9f4e7c_0
86
+ hdf5=1.12.1=nompi_h2386368_104
87
+ httpcore=0.15.0=pypi_0
88
+ httpx=0.23.0=pypi_0
89
+ icu=70.1=h27087fc_0
90
+ idna=3.3=pyhd3eb1b0_0
91
+ imagecodecs=2022.2.22=py310h3ac3b6e_6
92
+ imageio=2.19.3=pyhcf75d05_0
93
+ intel-openmp=2021.4.0=h06a4308_3561
94
+ ipykernel=6.13.0=pypi_0
95
+ ipython=8.4.0=pypi_0
96
+ ipython-genutils=0.2.0=pypi_0
97
+ jack=1.9.18=h8c3723f_1002
98
+ jasper=2.0.33=ha77e612_0
99
+ jedi=0.18.1=pypi_0
100
+ jinja2=3.1.2=pypi_0
101
+ joblib=1.1.0=pyhd8ed1ab_0
102
+ jpeg=9e=h7f8727e_0
103
+ json5=0.9.8=pypi_0
104
+ jsonschema=4.6.0=pypi_0
105
+ jupyter-client=7.3.1=pypi_0
106
+ jupyter-core=4.10.0=pypi_0
107
+ jupyter-server=1.17.0=pypi_0
108
+ jupyterlab=3.4.2=pypi_0
109
+ jupyterlab-pygments=0.2.2=pypi_0
110
+ jupyterlab-server=2.14.0=pypi_0
111
+ jxrlib=1.1=h7f98852_2
112
+ keyutils=1.6.1=h166bdaf_0
113
+ kiwisolver=1.4.2=pypi_0
114
+ kornia=0.6.5=pypi_0
115
+ krb5=1.19.3=h3790be6_0
116
+ lame=3.100=h7b6447c_0
117
+ lcms2=2.12=h3be6417_0
118
+ ld_impl_linux-64=2.38=h1181459_1
119
+ lerc=3.0=h9c3ff4c_0
120
+ libaec=1.0.6=h9c3ff4c_0
121
+ libavif=0.10.1=h166bdaf_0
122
+ libblas=3.9.0=12_linux64_mkl
123
+ libbrotlicommon=1.0.9=h166bdaf_7
124
+ libbrotlidec=1.0.9=h166bdaf_7
125
+ libbrotlienc=1.0.9=h166bdaf_7
126
+ libcap=2.64=ha37c62d_0
127
+ libcblas=3.9.0=12_linux64_mkl
128
+ libclang=14.0.6=default_h2e3cab8_0
129
+ libclang13=14.0.6=default_h3a83d3e_0
130
+ libcups=2.3.3=hf5a7f15_1
131
+ libcurl=7.83.1=h7bff187_0
132
+ libdb=6.2.32=h9c3ff4c_0
133
+ libdeflate=1.12=h166bdaf_0
134
+ libdrm=2.4.112=h166bdaf_0
135
+ libedit=3.1.20191231=he28a2e2_2
136
+ libev=4.33=h516909a_1
137
+ libevent=2.1.10=h9b69904_4
138
+ libffi=3.4.2=h7f98852_5
139
+ libflac=1.3.4=h27087fc_0
140
+ libgcc-ng=12.1.0=h8d9b700_16
141
+ libgfortran-ng=12.1.0=h69a702a_16
142
+ libgfortran5=12.1.0=hdcd56e2_16
143
+ libglib=2.70.2=h174f98d_4
144
+ libglu=9.0.0=he1b5a44_1001
145
+ libiconv=1.16=h7f8727e_2
146
+ libidn2=2.3.2=h7f8727e_0
147
+ liblapack=3.9.0=12_linux64_mkl
148
+ liblapacke=3.9.0=12_linux64_mkl
149
+ libllvm14=14.0.6=he0ac6c6_0
150
+ libnghttp2=1.47.0=h727a467_0
151
+ libnsl=2.0.0=h7f98852_0
152
+ libogg=1.3.4=h7f98852_1
153
+ libopencv=4.5.5=py310hcb97b83_13
154
+ libopus=1.3.1=h7f98852_1
155
+ libpciaccess=0.16=h516909a_0
156
+ libpng=1.6.37=hbc83047_0
157
+ libpq=14.4=hd77ab85_0
158
+ libprotobuf=3.20.1=h6239696_0
159
+ libsndfile=1.0.31=h9c3ff4c_1
160
+ libssh2=1.10.0=ha56f1ee_2
161
+ libstdcxx-ng=12.1.0=ha89aaad_16
162
+ libtasn1=4.16.0=h27cfd23_0
163
+ libtiff=4.4.0=hc85c160_1
164
+ libtool=2.4.6=h9c3ff4c_1008
165
+ libudev1=249=h166bdaf_4
166
+ libunistring=0.9.10=h27cfd23_0
167
+ libuuid=2.32.1=h7f98852_1000
168
+ libuv=1.40.0=h7b6447c_0
169
+ libva=2.15.0=h166bdaf_0
170
+ libvorbis=1.3.7=h9c3ff4c_0
171
+ libvpx=1.11.0=h9c3ff4c_3
172
+ libwebp=1.2.2=h55f646e_0
173
+ libwebp-base=1.2.2=h7f8727e_0
174
+ libxcb=1.13=h7f98852_1004
175
+ libxkbcommon=1.0.3=he3ba5ed_0
176
+ libxml2=2.9.14=h22db469_3
177
+ libzlib=1.2.12=h166bdaf_1
178
+ libzopfli=1.0.3=h9c3ff4c_0
179
+ linkify-it-py=1.0.3=pypi_0
180
+ llvm-openmp=14.0.4=he0ac6c6_0
181
+ locket=1.0.0=pyhd8ed1ab_0
182
+ lz4-c=1.9.3=h295c915_1
183
+ markdown-it-py=2.1.0=pypi_0
184
+ markupsafe=2.1.1=pypi_0
185
+ matplotlib=3.5.2=pypi_0
186
+ matplotlib-inline=0.1.3=pypi_0
187
+ mdit-py-plugins=0.3.0=pypi_0
188
+ mdurl=0.1.1=pypi_0
189
+ mistune=0.8.4=pypi_0
190
+ mkl=2021.4.0=h06a4308_640
191
+ mkl-service=2.4.0=py310h7f8727e_0
192
+ mkl_fft=1.3.1=py310hd6ae3a3_0
193
+ mkl_random=1.2.2=py310h00e6091_0
194
+ mltk=0.0.5=pypi_0
195
+ monotonic=1.6=pypi_0
196
+ multidict=6.0.2=pypi_0
197
+ munch=2.5.0=pypi_0
198
+ mysql-common=8.0.29=haf5c9bc_1
199
+ mysql-libs=8.0.29=h28c427c_1
200
+ nbclassic=0.3.7=pypi_0
201
+ nbclient=0.6.4=pypi_0
202
+ nbconvert=6.5.0=pypi_0
203
+ nbformat=5.4.0=pypi_0
204
+ ncurses=6.3=h7f8727e_2
205
+ nest-asyncio=1.5.5=pypi_0
206
+ nettle=3.7.3=hbbd107a_1
207
+ networkx=2.8.4=pyhd8ed1ab_0
208
+ nltk=3.7=pypi_0
209
+ notebook=6.4.11=pypi_0
210
+ notebook-shim=0.1.0=pypi_0
211
+ nspr=4.32=h9c3ff4c_1
212
+ nss=3.78=h2350873_0
213
+ ntk=1.1.3=pypi_0
214
+ numpy=1.22.3=py310hfa59a62_0
215
+ numpy-base=1.22.3=py310h9585f30_0
216
+ opencv=4.5.5=py310hff52083_13
217
+ opencv-python=4.6.0.66=pypi_0
218
+ openh264=2.1.1=h4ff587b_0
219
+ openjpeg=2.4.0=hb52868f_1
220
+ openssl=1.1.1q=h166bdaf_0
221
+ orjson=3.7.7=pypi_0
222
+ packaging=21.3=pyhd8ed1ab_0
223
+ pandas=1.4.2=pypi_0
224
+ pandocfilters=1.5.0=pypi_0
225
+ paramiko=2.11.0=pypi_0
226
+ parso=0.8.3=pypi_0
227
+ partd=1.2.0=pyhd8ed1ab_0
228
+ pcre=8.45=h9c3ff4c_0
229
+ pexpect=4.8.0=pypi_0
230
+ pickleshare=0.7.5=pypi_0
231
+ pillow=9.0.1=py310h22f2fdc_0
232
+ pip=21.2.4=py310h06a4308_0
233
+ pixman=0.40.0=h36c2ea0_0
234
+ portaudio=19.6.0=h57a0ea0_5
235
+ prometheus-client=0.14.1=pypi_0
236
+ prompt-toolkit=3.0.29=pypi_0
237
+ psutil=5.9.1=pypi_0
238
+ pthread-stubs=0.4=h36c2ea0_1001
239
+ ptyprocess=0.7.0=pypi_0
240
+ pulseaudio=14.0=h7f54b18_8
241
+ pure-eval=0.2.2=pypi_0
242
+ py-opencv=4.5.5=py310hfdc917e_13
243
+ pycocotools=2.0.4=pypi_0
244
+ pycparser=2.21=pyhd3eb1b0_0
245
+ pycryptodome=3.15.0=pypi_0
246
+ pydantic=1.9.1=pypi_0
247
+ pydub=0.25.1=pypi_0
248
+ pygments=2.12.0=pypi_0
249
+ pynacl=1.5.0=pypi_0
250
+ pyopenssl=22.0.0=pyhd3eb1b0_0
251
+ pyparsing=3.0.9=pyhd8ed1ab_0
252
+ pyrsistent=0.18.1=pypi_0
253
+ pysocks=1.7.1=py310h06a4308_0
254
+ python=3.10.5=h582c2e5_0_cpython
255
+ python-dateutil=2.8.2=pypi_0
256
+ python-multipart=0.0.5=pypi_0
257
+ python_abi=3.10=2_cp310
258
+ pytorch=1.11.0=py3.10_cuda10.2_cudnn7.6.5_0
259
+ pytorch-mutex=1.0=cuda
260
+ pytz=2022.1=pypi_0
261
+ pywavelets=1.3.0=py310hde88566_1
262
+ pyyaml=6.0=py310h5764c6d_4
263
+ pyzmq=23.1.0=pypi_0
264
+ qt-main=5.15.4=ha5833f6_2
265
+ qudida=0.0.4=pyhd8ed1ab_0
266
+ readline=8.1.2=h7f8727e_1
267
+ regex=2022.6.2=pypi_0
268
+ requests=2.27.1=pyhd3eb1b0_0
269
+ rfc3986=1.5.0=pypi_0
270
+ scikit-image=0.19.3=py310h769672d_0
271
+ scikit-learn=1.1.1=py310hffb9edd_0
272
+ scipy=1.8.1=py310h7612f91_0
273
+ seaborn=0.11.2=pypi_0
274
+ send2trash=1.8.0=pypi_0
275
+ setuptools=61.2.0=py310h06a4308_0
276
+ six=1.16.0=pyhd3eb1b0_1
277
+ snappy=1.1.9=hbd366e4_1
278
+ sniffio=1.2.0=pypi_0
279
+ soupsieve=2.3.2.post1=pypi_0
280
+ sqlite=3.39.0=h4ff8645_0
281
+ stack-data=0.2.0=pypi_0
282
+ starlette=0.19.1=pypi_0
283
+ svt-av1=1.1.0=h27087fc_1
284
+ termcolor=1.1.0=pypi_0
285
+ terminado=0.15.0=pypi_0
286
+ threadpoolctl=3.1.0=pyh8a188c0_0
287
+ tifffile=2022.5.4=pyhd8ed1ab_0
288
+ tinycss2=1.1.1=pypi_0
289
+ tk=8.6.12=h1ccaba5_0
290
+ toolz=0.11.2=pyhd8ed1ab_0
291
+ torchaudio=0.11.0=py310_cu102
292
+ torchvision=0.12.0=py310_cu102
293
+ tornado=6.1=pypi_0
294
+ tqdm=4.64.0=pypi_0
295
+ traitlets=5.2.2.post1=pypi_0
296
+ typing-extensions=4.1.1=hd3eb1b0_0
297
+ typing_extensions=4.1.1=pyh06a4308_0
298
+ tzdata=2022a=hda174b7_0
299
+ uc-micro-py=1.0.1=pypi_0
300
+ urllib3=1.26.9=py310h06a4308_0
301
+ uvicorn=0.18.2=pypi_0
302
+ wcwidth=0.2.5=pypi_0
303
+ webencodings=0.5.1=pypi_0
304
+ websocket-client=1.3.2=pypi_0
305
+ wheel=0.37.1=pyhd3eb1b0_0
306
+ x264=1!161.3030=h7f98852_1
307
+ x265=3.5=h924138e_3
308
+ xcb-util=0.4.0=h166bdaf_0
309
+ xcb-util-image=0.4.0=h166bdaf_0
310
+ xcb-util-keysyms=0.4.0=h166bdaf_0
311
+ xcb-util-renderutil=0.3.9=h166bdaf_0
312
+ xcb-util-wm=0.4.1=h166bdaf_0
313
+ xorg-fixesproto=5.0=h7f98852_1002
314
+ xorg-inputproto=2.3.2=h7f98852_1002
315
+ xorg-kbproto=1.0.7=h7f98852_1002
316
+ xorg-libice=1.0.10=h7f98852_0
317
+ xorg-libsm=1.2.3=hd9c2040_1000
318
+ xorg-libx11=1.7.2=h7f98852_0
319
+ xorg-libxau=1.0.9=h7f98852_0
320
+ xorg-libxdmcp=1.1.3=h7f98852_0
321
+ xorg-libxext=1.3.4=h7f98852_1
322
+ xorg-libxfixes=5.0.3=h7f98852_1004
323
+ xorg-libxi=1.7.10=h7f98852_0
324
+ xorg-libxrender=0.9.10=h7f98852_1003
325
+ xorg-renderproto=0.11.1=h7f98852_1002
326
+ xorg-xextproto=7.3.0=h7f98852_1002
327
+ xorg-xproto=7.0.31=h7f98852_1007
328
+ xz=5.2.5=h7f8727e_1
329
+ yaml=0.2.5=h7f98852_2
330
+ yarl=1.7.2=pypi_0
331
+ zfp=0.5.5=h9c3ff4c_8
332
+ zlib=1.2.12=h166bdaf_1
333
+ zlib-ng=2.0.6=h166bdaf_0
334
+ zstd=1.5.2=ha4553b6_0
src/.gitkeep ADDED
File without changes
src/__pycache__/ds.cpython-310.pyc ADDED
Binary file (14.6 kB). View file
 
src/__pycache__/losses.cpython-310.pyc ADDED
Binary file (4.17 kB). View file
 
src/__pycache__/networks_SRGAN.cpython-310.pyc ADDED
Binary file (6.99 kB). View file
 
src/__pycache__/utils.cpython-310.pyc ADDED
Binary file (34 kB). View file
 
src/app.py ADDED
@@ -0,0 +1,115 @@
1
+ import numpy as np
2
+ import random
3
+ import matplotlib.pyplot as plt
4
+ from matplotlib import cm
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torchvision.models as models
10
+ from torch.utils.data import Dataset, DataLoader
11
+ from torchvision import transforms
12
+ from torchvision.transforms.functional import InterpolationMode as IMode
13
+
14
+ from PIL import Image
15
+
16
+ from ds import *
17
+ from losses import *
18
+ from networks_SRGAN import *
19
+ from utils import *
20
+
21
+
22
+ NetG = Generator()
23
+ model_parameters = filter(lambda p: True, NetG.parameters())
24
+ params = sum([np.prod(p.size()) for p in model_parameters])
25
+ print("Number of Parameters:",params)
26
+ NetC = BayesCap(in_channels=3, out_channels=3)
27
+
28
+
29
+ NetG = Generator()
30
+ NetG.load_state_dict(torch.load('../ckpt/srgan-ImageNet-bc347d67.pth', map_location='cuda:0'))
31
+ NetG.to('cuda')
32
+ NetG.eval()
33
+
34
+ NetC = BayesCap(in_channels=3, out_channels=3)
35
+ NetC.load_state_dict(torch.load('../ckpt/BayesCap_SRGAN_best.pth', map_location='cuda:0'))
36
+ NetC.to('cuda')
37
+ NetC.eval()
38
+
39
+ def tensor01_to_pil(xt):
40
+ r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
41
+ return r
42
+
43
+
44
+ def predict(img):
45
+ """
46
+ img: input image (PIL.Image or array); returns (low-res, super-res, alpha, beta, uncertainty) as PIL images
47
+ """
48
+ image_size = (256,256)
49
+ upscale_factor = 4
50
+ lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
51
+ # lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
52
+
53
+ img = Image.fromarray(np.array(img))
54
+ img = lr_transforms(img)
55
+ lr_tensor = utils.image2tensor(img, range_norm=False, half=False)
56
+
57
+ device = 'cuda'
58
+ dtype = torch.cuda.FloatTensor
59
+ xLR = lr_tensor.to(device).unsqueeze(0)
60
+ xLR = xLR.type(dtype)
61
+ # pass them through the network
62
+ with torch.no_grad():
63
+ xSR = NetG(xLR)
64
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
65
+
66
+ a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
67
+ b_map = xSRC_beta[0].to('cpu').data
68
+ u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
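+ # Note (interpretation of the formula above): u_map is the per-pixel variance of a
+ # generalized Gaussian with scale s = 1/alpha (a_map) and shape beta (b_map),
+ # Var = s**2 * Gamma(3/beta) / Gamma(1/beta); the 1e-5 and 1e-2 offsets are only
+ # for numerical stability.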
69
+
70
+
71
+ x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
72
+
73
+ x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
74
+
75
+ #im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))
76
+
77
+ a_map = torch.clamp(a_map, min=0, max=0.1)
78
+ a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
79
+ x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))
80
+
81
+ b_map = torch.clamp(b_map, min=0.45, max=0.75)
82
+ b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
83
+ x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))
84
+
85
+ u_map = torch.clamp(u_map, min=0, max=0.15)
86
+ u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
87
+ x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))
88
+
89
+ return x_LR, x_mean, x_alpha, x_beta, x_uncer
90
+
91
+ import gradio as gr
92
+
93
+ title = "BayesCap"
94
+ description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
95
+ article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks | <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
96
+
97
+
98
+ gr.Interface(
99
+ fn=predict,
100
+ inputs=gr.inputs.Image(type='pil', label="Orignal"),
101
+ outputs=[
102
+ gr.outputs.Image(type='pil', label="Low-res"),
103
+ gr.outputs.Image(type='pil', label="Super-res"),
104
+ gr.outputs.Image(type='pil', label="Alpha"),
105
+ gr.outputs.Image(type='pil', label="Beta"),
106
+ gr.outputs.Image(type='pil', label="Uncertainty")
107
+ ],
108
+ title=title,
109
+ description=description,
110
+ article=article,
111
+ examples=[
112
+ ["../demo_examples/baby.png"],
113
+ ["../demo_examples/bird.png"]
114
+ ]
115
+ ).launch(share=True)
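+
+ # To try the demo locally (assumption: the checkpoints exist under ../ckpt/ and the script
+ # is run from inside src/ so the relative ../ckpt and ../demo_examples paths resolve),
+ # something like `python app.py` on a CUDA machine should start the Gradio interface;
+ # share=True additionally creates a temporary public link.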
src/ds.py ADDED
@@ -0,0 +1,485 @@
1
+ from __future__ import absolute_import, division, print_function
2
+
3
+ import random
4
+ import copy
5
+ import io
6
+ import os
7
+ import numpy as np
8
+ from PIL import Image
9
+ import skimage.transform
10
+ from collections import Counter
11
+
12
+
13
+ import torch
14
+ import torch.utils.data as data
15
+ from torch import Tensor
16
+ from torch.utils.data import Dataset
17
+ from torchvision import transforms
18
+ from torchvision.transforms.functional import InterpolationMode as IMode
19
+
20
+ import utils
21
+
22
+ class ImgDset(Dataset):
23
+ """Customize the data set loading function and prepare low/high resolution image data in advance.
24
+
25
+ Args:
26
+ dataroot (str): Training data set address
27
+ image_size (int): High resolution image size
28
+ upscale_factor (int): Image magnification
29
+ mode (str): Data set loading method, the training data set is for data enhancement,
30
+ and the verification data set is not for data enhancement
31
+
32
+ """
33
+
34
+ def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
35
+ super(ImgDset, self).__init__()
36
+ self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
37
+
38
+ if mode == "train":
39
+ self.hr_transforms = transforms.Compose([
40
+ transforms.RandomCrop(image_size),
41
+ transforms.RandomRotation(90),
42
+ transforms.RandomHorizontalFlip(0.5),
43
+ ])
44
+ else:
45
+ self.hr_transforms = transforms.Resize(image_size)
46
+
47
+ self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
48
+
49
+ def __getitem__(self, batch_index: int) -> [Tensor, Tensor]:
50
+ # Read a batch of image data
51
+ image = Image.open(self.filenames[batch_index])
52
+
53
+ # Transform image
54
+ hr_image = self.hr_transforms(image)
55
+ lr_image = self.lr_transforms(hr_image)
56
+
57
+ # Convert image data into Tensor stream format (PyTorch).
58
+ # Note: The range of input and output is between [0, 1]
59
+ lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
60
+ hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
61
+
62
+ return lr_tensor, hr_tensor
63
+
64
+ def __len__(self) -> int:
65
+ return len(self.filenames)
66
+
67
+
68
+ class PairedImages_w_nameList(Dataset):
69
+ '''
70
+ can act as supervised or un-supervised based on flists
71
+ '''
72
+ def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
73
+ self.flist1 = flist1
74
+ self.flist2 = flist2
75
+ self.transform1 = transform1
76
+ self.transform2 = transform2
77
+ self.do_aug = do_aug
78
+ def __getitem__(self, index):
79
+ impath1 = self.flist1[index]
80
+ img1 = Image.open(impath1).convert('RGB')
81
+ impath2 = self.flist2[index]
82
+ img2 = Image.open(impath2).convert('RGB')
83
+
84
+ img1 = utils.image2tensor(img1, range_norm=False, half=False)
85
+ img2 = utils.image2tensor(img2, range_norm=False, half=False)
86
+
87
+ if self.transform1 is not None:
88
+ img1 = self.transform1(img1)
89
+ if self.transform2 is not None:
90
+ img2 = self.transform2(img2)
91
+
92
+ return img1, img2
93
+ def __len__(self):
94
+ return len(self.flist1)
95
+
96
+ class PairedImages_w_nameList_npy(Dataset):
97
+ '''
98
+ can act as supervised or un-supervised based on flists
99
+ '''
100
+ def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
101
+ self.flist1 = flist1
102
+ self.flist2 = flist2
103
+ self.transform1 = transform1
104
+ self.transform2 = transform2
105
+ self.do_aug = do_aug
106
+ def __getitem__(self, index):
107
+ impath1 = self.flist1[index]
108
+ img1 = np.load(impath1)
109
+ impath2 = self.flist2[index]
110
+ img2 = np.load(impath2)
111
+
112
+ if self.transform1 is not None:
113
+ img1 = self.transform1(img1)
114
+ if self.transform2 is not None:
115
+ img2 = self.transform2(img2)
116
+
117
+ return img1, img2
118
+ def __len__(self):
119
+ return len(self.flist1)
120
+
121
+ # def call_paired():
122
+ # root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
123
+ # root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
124
+
125
+ # flist1=glob.glob(root1+'/*/*.png')
126
+ # flist2=glob.glob(root2+'/*/*.png')
127
+
128
+ # dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
129
+
130
+ #### KITTI depth
131
+
132
+ def load_velodyne_points(filename):
133
+ """Load 3D point cloud from KITTI file format
134
+ (adapted from https://github.com/hunse/kitti)
135
+ """
136
+ points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
137
+ points[:, 3] = 1.0 # homogeneous
138
+ return points
139
+
140
+
141
+ def read_calib_file(path):
142
+ """Read KITTI calibration file
143
+ (from https://github.com/hunse/kitti)
144
+ """
145
+ float_chars = set("0123456789.e+- ")
146
+ data = {}
147
+ with open(path, 'r') as f:
148
+ for line in f.readlines():
149
+ key, value = line.split(':', 1)
150
+ value = value.strip()
151
+ data[key] = value
152
+ if float_chars.issuperset(value):
153
+ # try to cast to float array
154
+ try:
155
+ data[key] = np.array(list(map(float, value.split(' '))))
156
+ except ValueError:
157
+ # casting error: data[key] already eq. value, so pass
158
+ pass
159
+
160
+ return data
161
+
162
+
163
+ def sub2ind(matrixSize, rowSub, colSub):
164
+ """Convert row, col matrix subscripts to linear indices
165
+ """
166
+ m, n = matrixSize
167
+ return rowSub * (n-1) + colSub - 1
168
+
169
+
170
+ def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
171
+ """Generate a depth map from velodyne data
172
+ """
173
+ # load calibration files
174
+ cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
175
+ velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
176
+ velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
177
+ velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
178
+
179
+ # get image shape
180
+ im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
181
+
182
+ # compute projection matrix velodyne->image plane
183
+ R_cam2rect = np.eye(4)
184
+ R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
185
+ P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
186
+ P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
187
+
188
+ # load velodyne points and remove all behind image plane (approximation)
189
+ # each row of the velodyne data is forward, left, up, reflectance
190
+ velo = load_velodyne_points(velo_filename)
191
+ velo = velo[velo[:, 0] >= 0, :]
192
+
193
+ # project the points to the camera
194
+ velo_pts_im = np.dot(P_velo2im, velo.T).T
195
+ velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
196
+
197
+ if vel_depth:
198
+ velo_pts_im[:, 2] = velo[:, 0]
199
+
200
+ # check if in bounds
201
+ # use minus 1 to get the exact same value as KITTI matlab code
202
+ velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
203
+ velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
204
+ val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
205
+ val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
206
+ velo_pts_im = velo_pts_im[val_inds, :]
207
+
208
+ # project to image
209
+ depth = np.zeros((im_shape[:2]))
210
+ depth[velo_pts_im[:, 1].astype(np.int32), velo_pts_im[:, 0].astype(np.int32)] = velo_pts_im[:, 2]
211
+
212
+ # find the duplicate points and choose the closest depth
213
+ inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
214
+ dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
215
+ for dd in dupe_inds:
216
+ pts = np.where(inds == dd)[0]
217
+ x_loc = int(velo_pts_im[pts[0], 0])
218
+ y_loc = int(velo_pts_im[pts[0], 1])
219
+ depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
220
+ depth[depth < 0] = 0
221
+
222
+ return depth
223
+
224
+ def pil_loader(path):
225
+ # open path as file to avoid ResourceWarning
226
+ # (https://github.com/python-pillow/Pillow/issues/835)
227
+ with open(path, 'rb') as f:
228
+ with Image.open(f) as img:
229
+ return img.convert('RGB')
230
+
231
+
232
+ class MonoDataset(data.Dataset):
233
+ """Superclass for monocular dataloaders
234
+
235
+ Args:
236
+ data_path
237
+ filenames
238
+ height
239
+ width
240
+ frame_idxs
241
+ num_scales
242
+ is_train
243
+ img_ext
244
+ """
245
+ def __init__(self,
246
+ data_path,
247
+ filenames,
248
+ height,
249
+ width,
250
+ frame_idxs,
251
+ num_scales,
252
+ is_train=False,
253
+ img_ext='.jpg'):
254
+ super(MonoDataset, self).__init__()
255
+
256
+ self.data_path = data_path
257
+ self.filenames = filenames
258
+ self.height = height
259
+ self.width = width
260
+ self.num_scales = num_scales
261
+ self.interp = Image.ANTIALIAS
262
+
263
+ self.frame_idxs = frame_idxs
264
+
265
+ self.is_train = is_train
266
+ self.img_ext = img_ext
267
+
268
+ self.loader = pil_loader
269
+ self.to_tensor = transforms.ToTensor()
270
+
271
+ # We need to specify augmentations differently in newer versions of torchvision.
272
+ # We first try the newer tuple version; if this fails we fall back to scalars
273
+ try:
274
+ self.brightness = (0.8, 1.2)
275
+ self.contrast = (0.8, 1.2)
276
+ self.saturation = (0.8, 1.2)
277
+ self.hue = (-0.1, 0.1)
278
+ transforms.ColorJitter.get_params(
279
+ self.brightness, self.contrast, self.saturation, self.hue)
280
+ except TypeError:
281
+ self.brightness = 0.2
282
+ self.contrast = 0.2
283
+ self.saturation = 0.2
284
+ self.hue = 0.1
285
+
286
+ self.resize = {}
287
+ for i in range(self.num_scales):
288
+ s = 2 ** i
289
+ self.resize[i] = transforms.Resize((self.height // s, self.width // s),
290
+ interpolation=self.interp)
291
+
292
+ self.load_depth = self.check_depth()
293
+
294
+ def preprocess(self, inputs, color_aug):
295
+ """Resize colour images to the required scales and augment if required
296
+
297
+ We create the color_aug object in advance and apply the same augmentation to all
298
+ images in this item. This ensures that all images input to the pose network receive the
299
+ same augmentation.
300
+ """
301
+ for k in list(inputs):
302
+ frame = inputs[k]
303
+ if "color" in k:
304
+ n, im, i = k
305
+ for i in range(self.num_scales):
306
+ inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
307
+
308
+ for k in list(inputs):
309
+ f = inputs[k]
310
+ if "color" in k:
311
+ n, im, i = k
312
+ inputs[(n, im, i)] = self.to_tensor(f)
313
+ inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
314
+
315
+ def __len__(self):
316
+ return len(self.filenames)
317
+
318
+ def __getitem__(self, index):
319
+ """Returns a single training item from the dataset as a dictionary.
320
+
321
+ Values correspond to torch tensors.
322
+ Keys in the dictionary are either strings or tuples:
323
+
324
+ ("color", <frame_id>, <scale>) for raw colour images,
325
+ ("color_aug", <frame_id>, <scale>) for augmented colour images,
326
+ ("K", scale) or ("inv_K", scale) for camera intrinsics,
327
+ "stereo_T" for camera extrinsics, and
328
+ "depth_gt" for ground truth depth maps.
329
+
330
+ <frame_id> is either:
331
+ an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
332
+ or
333
+ "s" for the opposite image in the stereo pair.
334
+
335
+ <scale> is an integer representing the scale of the image relative to the fullsize image:
336
+ -1 images at native resolution as loaded from disk
337
+ 0 images resized to (self.width, self.height )
338
+ 1 images resized to (self.width // 2, self.height // 2)
339
+ 2 images resized to (self.width // 4, self.height // 4)
340
+ 3 images resized to (self.width // 8, self.height // 8)
341
+ """
342
+ inputs = {}
343
+
344
+ do_color_aug = self.is_train and random.random() > 0.5
345
+ do_flip = self.is_train and random.random() > 0.5
346
+
347
+ line = self.filenames[index].split()
348
+ folder = line[0]
349
+
350
+ if len(line) == 3:
351
+ frame_index = int(line[1])
352
+ else:
353
+ frame_index = 0
354
+
355
+ if len(line) == 3:
356
+ side = line[2]
357
+ else:
358
+ side = None
359
+
360
+ for i in self.frame_idxs:
361
+ if i == "s":
362
+ other_side = {"r": "l", "l": "r"}[side]
363
+ inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
364
+ else:
365
+ inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)
366
+
367
+ # adjusting intrinsics to match each scale in the pyramid
368
+ for scale in range(self.num_scales):
369
+ K = self.K.copy()
370
+
371
+ K[0, :] *= self.width // (2 ** scale)
372
+ K[1, :] *= self.height // (2 ** scale)
373
+
374
+ inv_K = np.linalg.pinv(K)
375
+
376
+ inputs[("K", scale)] = torch.from_numpy(K)
377
+ inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
378
+
379
+ if do_color_aug:
380
+ color_aug = transforms.ColorJitter.get_params(
381
+ self.brightness, self.contrast, self.saturation, self.hue)
382
+ else:
383
+ color_aug = (lambda x: x)
384
+
385
+ self.preprocess(inputs, color_aug)
386
+
387
+ for i in self.frame_idxs:
388
+ del inputs[("color", i, -1)]
389
+ del inputs[("color_aug", i, -1)]
390
+
391
+ if self.load_depth:
392
+ depth_gt = self.get_depth(folder, frame_index, side, do_flip)
393
+ inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
394
+ inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))
395
+
396
+ if "s" in self.frame_idxs:
397
+ stereo_T = np.eye(4, dtype=np.float32)
398
+ baseline_sign = -1 if do_flip else 1
399
+ side_sign = -1 if side == "l" else 1
400
+ stereo_T[0, 3] = side_sign * baseline_sign * 0.1
401
+
402
+ inputs["stereo_T"] = torch.from_numpy(stereo_T)
403
+
404
+ return inputs
405
+
406
+ def get_color(self, folder, frame_index, side, do_flip):
407
+ raise NotImplementedError
408
+
409
+ def check_depth(self):
410
+ raise NotImplementedError
411
+
412
+ def get_depth(self, folder, frame_index, side, do_flip):
413
+ raise NotImplementedError
414
+
415
+ class KITTIDataset(MonoDataset):
416
+ """Superclass for different types of KITTI dataset loaders
417
+ """
418
+ def __init__(self, *args, **kwargs):
419
+ super(KITTIDataset, self).__init__(*args, **kwargs)
420
+
421
+ # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
422
+ # To normalize you need to scale the first row by 1 / image_width and the second row
423
+ # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
424
+ # If your principal point is far from the center you might need to disable the horizontal
425
+ # flip augmentation.
426
+ self.K = np.array([[0.58, 0, 0.5, 0],
427
+ [0, 1.92, 0.5, 0],
428
+ [0, 0, 1, 0],
429
+ [0, 0, 0, 1]], dtype=np.float32)
430
+
431
+ self.full_res_shape = (1242, 375)
432
+ self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
433
+
434
+ def check_depth(self):
435
+ line = self.filenames[0].split()
436
+ scene_name = line[0]
437
+ frame_index = int(line[1])
438
+
439
+ velo_filename = os.path.join(
440
+ self.data_path,
441
+ scene_name,
442
+ "velodyne_points/data/{:010d}.bin".format(int(frame_index)))
443
+
444
+ return os.path.isfile(velo_filename)
445
+
446
+ def get_color(self, folder, frame_index, side, do_flip):
447
+ color = self.loader(self.get_image_path(folder, frame_index, side))
448
+
449
+ if do_flip:
450
+ color = color.transpose(Image.FLIP_LEFT_RIGHT)
451
+
452
+ return color
453
+
454
+
455
+ class KITTIDepthDataset(KITTIDataset):
456
+ """KITTI dataset which uses the updated ground truth depth maps
457
+ """
458
+ def __init__(self, *args, **kwargs):
459
+ super(KITTIDepthDataset, self).__init__(*args, **kwargs)
460
+
461
+ def get_image_path(self, folder, frame_index, side):
462
+ f_str = "{:010d}{}".format(frame_index, self.img_ext)
463
+ image_path = os.path.join(
464
+ self.data_path,
465
+ folder,
466
+ "image_0{}/data".format(self.side_map[side]),
467
+ f_str)
468
+ return image_path
469
+
470
+ def get_depth(self, folder, frame_index, side, do_flip):
471
+ f_str = "{:010d}.png".format(frame_index)
472
+ depth_path = os.path.join(
473
+ self.data_path,
474
+ folder,
475
+ "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
476
+ f_str)
477
+
478
+ depth_gt = Image.open(depth_path)
479
+ depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
480
+ depth_gt = np.array(depth_gt).astype(np.float32) / 256
481
+
482
+ if do_flip:
483
+ depth_gt = np.fliplr(depth_gt)
484
+
485
+ return depth_gt
src/flagged/Alpha/0.png ADDED
src/flagged/Beta/0.png ADDED
src/flagged/Low-res/0.png ADDED
src/flagged/Orignal/0.png ADDED
src/flagged/Super-res/0.png ADDED
src/flagged/Uncertainty/0.png ADDED
src/flagged/log.csv ADDED
@@ -0,0 +1,2 @@
1
+ 'Orignal','Low-res','Super-res','Alpha','Beta','Uncertainty','flag','username','timestamp'
2
+ 'Orignal/0.png','Low-res/0.png','Super-res/0.png','Alpha/0.png','Beta/0.png','Uncertainty/0.png','','','2022-07-09 14:01:12.964411'
src/losses.py ADDED
@@ -0,0 +1,131 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ from torch import Tensor
6
+
7
+ class ContentLoss(nn.Module):
8
+ """Constructs a content loss function based on the VGG19 network.
9
+ Using feature maps from the later, high-level layers focuses the loss on the texture content of the image.
10
+
11
+ Paper reference list:
12
+ -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
13
+ -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
14
+ -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
15
+
16
+ """
17
+
18
+ def __init__(self) -> None:
19
+ super(ContentLoss, self).__init__()
20
+ # Load the VGG19 model trained on the ImageNet dataset.
21
+ vgg19 = models.vgg19(pretrained=True).eval()
22
+ # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
23
+ self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
24
+ # Freeze model parameters.
25
+ for parameters in self.feature_extractor.parameters():
26
+ parameters.requires_grad = False
27
+
28
+ # The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
29
+ self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
30
+ self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
31
+
32
+ def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
33
+ # Standardized operations
34
+ sr = sr.sub(self.mean).div(self.std)
35
+ hr = hr.sub(self.mean).div(self.std)
36
+
37
+ # Find the feature map difference between the two images
38
+ loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))
39
+
40
+ return loss
41
+
42
+
43
+ class GenGaussLoss(nn.Module):
44
+ def __init__(
45
+ self, reduction='mean',
46
+ alpha_eps = 1e-4, beta_eps=1e-4,
47
+ resi_min = 1e-4, resi_max=1e3
48
+ ) -> None:
49
+ super(GenGaussLoss, self).__init__()
50
+ self.reduction = reduction
51
+ self.alpha_eps = alpha_eps
52
+ self.beta_eps = beta_eps
53
+ self.resi_min = resi_min
54
+ self.resi_max = resi_max
55
+
56
+ def forward(
57
+ self,
58
+ mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
59
+ ):
60
+ one_over_alpha1 = one_over_alpha + self.alpha_eps
61
+ beta1 = beta + self.beta_eps
62
+
63
+ resi = torch.abs(mean - target)
64
+ # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
65
+ resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
66
+ ## check if resi has nans
67
+ if torch.sum(resi != resi) > 0:
68
+ print('resi has nans!!')
69
+ return None
70
+
71
+ log_one_over_alpha = torch.log(one_over_alpha1)
72
+ log_beta = torch.log(beta1)
73
+ lgamma_beta = torch.lgamma(torch.pow(beta1, -1))
74
+
75
+ if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
76
+ print('log_one_over_alpha has nan')
77
+ if torch.sum(lgamma_beta != lgamma_beta) > 0:
78
+ print('lgamma_beta has nan')
79
+ if torch.sum(log_beta != log_beta) > 0:
80
+ print('log_beta has nan')
81
+
82
+ l = resi - log_one_over_alpha + lgamma_beta - log_beta
83
+
84
+ if self.reduction == 'mean':
85
+ return l.mean()
86
+ elif self.reduction == 'sum':
87
+ return l.sum()
88
+ else:
89
+ print('Reduction not supported')
90
+ return None
91
+
92
+ class TempCombLoss(nn.Module):
93
+ def __init__(
94
+ self, reduction='mean',
95
+ alpha_eps = 1e-4, beta_eps=1e-4,
96
+ resi_min = 1e-4, resi_max=1e3
97
+ ) -> None:
98
+ super(TempCombLoss, self).__init__()
99
+ self.reduction = reduction
100
+ self.alpha_eps = alpha_eps
101
+ self.beta_eps = beta_eps
102
+ self.resi_min = resi_min
103
+ self.resi_max = resi_max
104
+
105
+ self.L_GenGauss = GenGaussLoss(
106
+ reduction=self.reduction,
107
+ alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
108
+ resi_min=self.resi_min, resi_max=self.resi_max
109
+ )
110
+ self.L_l1 = nn.L1Loss(reduction=self.reduction)
111
+
112
+ def forward(
113
+ self,
114
+ mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
115
+ T1: float, T2: float
116
+ ):
117
+ l1 = self.L_l1(mean, target)
118
+ l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
119
+ l = T1*l1 + T2*l2
120
+
121
+ return l
122
+
123
+
124
+ # x1 = torch.randn(4,3,32,32)
125
+ # x2 = torch.rand(4,3,32,32)
126
+ # x3 = torch.rand(4,3,32,32)
127
+ # x4 = torch.randn(4,3,32,32)
128
+
129
+ # L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
130
+ # L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
131
+ # print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
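+
+ # For reference (an assumption about the intended model): GenGaussLoss is, up to an
+ # additive constant, the negative log-likelihood of a generalized Gaussian with scale
+ # 1/one_over_alpha and shape beta,
+ #   -log p(x) = (|x - mu| * one_over_alpha)**beta + lgamma(1/beta) - log(beta) - log(one_over_alpha) + const,
+ # except that the power term is replaced in the forward pass above by the clamped
+ # product (|x - mu| * one_over_alpha * beta), as noted in the commented-out line.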
src/networks_SRGAN.py ADDED
@@ -0,0 +1,347 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.models as models
5
+ from torch import Tensor
6
+
7
+ # __all__ = [
8
+ # "ResidualConvBlock",
9
+ # "Discriminator", "Generator",
10
+ # ]
11
+
12
+
13
+ class ResidualConvBlock(nn.Module):
14
+ """Implements residual conv function.
15
+
16
+ Args:
17
+ channels (int): Number of channels in the input image.
18
+ """
19
+
20
+ def __init__(self, channels: int) -> None:
21
+ super(ResidualConvBlock, self).__init__()
22
+ self.rcb = nn.Sequential(
23
+ nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
24
+ nn.BatchNorm2d(channels),
25
+ nn.PReLU(),
26
+ nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
27
+ nn.BatchNorm2d(channels),
28
+ )
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ identity = x
32
+
33
+ out = self.rcb(x)
34
+ out = torch.add(out, identity)
35
+
36
+ return out
37
+
38
+
39
+ class Discriminator(nn.Module):
40
+ def __init__(self) -> None:
41
+ super(Discriminator, self).__init__()
42
+ self.features = nn.Sequential(
43
+ # input size. (3) x 96 x 96
44
+ nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
45
+ nn.LeakyReLU(0.2, True),
46
+ # state size. (64) x 48 x 48
47
+ nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
48
+ nn.BatchNorm2d(64),
49
+ nn.LeakyReLU(0.2, True),
50
+ nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
51
+ nn.BatchNorm2d(128),
52
+ nn.LeakyReLU(0.2, True),
53
+ # state size. (128) x 24 x 24
54
+ nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
55
+ nn.BatchNorm2d(128),
56
+ nn.LeakyReLU(0.2, True),
57
+ nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
58
+ nn.BatchNorm2d(256),
59
+ nn.LeakyReLU(0.2, True),
60
+ # state size. (256) x 12 x 12
61
+ nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
62
+ nn.BatchNorm2d(256),
63
+ nn.LeakyReLU(0.2, True),
64
+ nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
65
+ nn.BatchNorm2d(512),
66
+ nn.LeakyReLU(0.2, True),
67
+ # state size. (512) x 6 x 6
68
+ nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
69
+ nn.BatchNorm2d(512),
70
+ nn.LeakyReLU(0.2, True),
71
+ )
72
+
73
+ self.classifier = nn.Sequential(
74
+ nn.Linear(512 * 6 * 6, 1024),
75
+ nn.LeakyReLU(0.2, True),
76
+ nn.Linear(1024, 1),
77
+ )
78
+
79
+ def forward(self, x: Tensor) -> Tensor:
80
+ out = self.features(x)
81
+ out = torch.flatten(out, 1)
82
+ out = self.classifier(out)
83
+
84
+ return out
85
+
86
+
87
+ class Generator(nn.Module):
88
+ def __init__(self) -> None:
89
+ super(Generator, self).__init__()
90
+ # First conv layer.
91
+ self.conv_block1 = nn.Sequential(
92
+ nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
93
+ nn.PReLU(),
94
+ )
95
+
96
+ # Features trunk blocks.
97
+ trunk = []
98
+ for _ in range(16):
99
+ trunk.append(ResidualConvBlock(64))
100
+ self.trunk = nn.Sequential(*trunk)
101
+
102
+ # Second conv layer.
103
+ self.conv_block2 = nn.Sequential(
104
+ nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
105
+ nn.BatchNorm2d(64),
106
+ )
107
+
108
+ # Upscale conv block.
109
+ self.upsampling = nn.Sequential(
110
+ nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
111
+ nn.PixelShuffle(2),
112
+ nn.PReLU(),
113
+ nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
114
+ nn.PixelShuffle(2),
115
+ nn.PReLU(),
116
+ )
117
+
118
+ # Output layer.
119
+ self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))
120
+
121
+ # Initialize neural network weights.
122
+ self._initialize_weights()
123
+
124
+ def forward(self, x: Tensor, dop=None) -> Tensor:
125
+ if not dop:
126
+ return self._forward_impl(x)
127
+ else:
128
+ return self._forward_w_dop_impl(x, dop)
129
+
130
+ # Support torch.script function.
131
+ def _forward_impl(self, x: Tensor) -> Tensor:
132
+ out1 = self.conv_block1(x)
133
+ out = self.trunk(out1)
134
+ out2 = self.conv_block2(out)
135
+ out = torch.add(out1, out2)
136
+ out = self.upsampling(out)
137
+ out = self.conv_block3(out)
138
+
139
+ return out
140
+
141
+ def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
142
+ out1 = self.conv_block1(x)
143
+ out = self.trunk(out1)
144
+ out2 = F.dropout2d(self.conv_block2(out), p=dop)
145
+ out = torch.add(out1, out2)
146
+ out = self.upsampling(out)
147
+ out = self.conv_block3(out)
148
+
149
+ return out
150
+
151
+ def _initialize_weights(self) -> None:
152
+ for module in self.modules():
153
+ if isinstance(module, nn.Conv2d):
154
+ nn.init.kaiming_normal_(module.weight)
155
+ if module.bias is not None:
156
+ nn.init.constant_(module.bias, 0)
157
+ elif isinstance(module, nn.BatchNorm2d):
158
+ nn.init.constant_(module.weight, 1)
159
+
160
+
161
+ #### BayesCap
162
+ class BayesCap(nn.Module):
163
+ def __init__(self, in_channels=3, out_channels=3) -> None:
164
+ super(BayesCap, self).__init__()
165
+ # First conv layer.
166
+ self.conv_block1 = nn.Sequential(
167
+ nn.Conv2d(
168
+ in_channels, 64,
169
+ kernel_size=9, stride=1, padding=4
170
+ ),
171
+ nn.PReLU(),
172
+ )
173
+
174
+ # Features trunk blocks.
175
+ trunk = []
176
+ for _ in range(16):
177
+ trunk.append(ResidualConvBlock(64))
178
+ self.trunk = nn.Sequential(*trunk)
179
+
180
+ # Second conv layer.
181
+ self.conv_block2 = nn.Sequential(
182
+ nn.Conv2d(
183
+ 64, 64,
184
+ kernel_size=3, stride=1, padding=1, bias=False
185
+ ),
186
+ nn.BatchNorm2d(64),
187
+ )
188
+
189
+ # Output layer.
190
+ self.conv_block3_mu = nn.Conv2d(
191
+ 64, out_channels=out_channels,
192
+ kernel_size=9, stride=1, padding=4
193
+ )
194
+ self.conv_block3_alpha = nn.Sequential(
195
+ nn.Conv2d(
196
+ 64, 64,
197
+ kernel_size=9, stride=1, padding=4
198
+ ),
199
+ nn.PReLU(),
200
+ nn.Conv2d(
201
+ 64, 64,
202
+ kernel_size=9, stride=1, padding=4
203
+ ),
204
+ nn.PReLU(),
205
+ nn.Conv2d(
206
+ 64, 1,
207
+ kernel_size=9, stride=1, padding=4
208
+ ),
209
+ nn.ReLU(),
210
+ )
211
+ self.conv_block3_beta = nn.Sequential(
212
+ nn.Conv2d(
213
+ 64, 64,
214
+ kernel_size=9, stride=1, padding=4
215
+ ),
216
+ nn.PReLU(),
217
+ nn.Conv2d(
218
+ 64, 64,
219
+ kernel_size=9, stride=1, padding=4
220
+ ),
221
+ nn.PReLU(),
222
+ nn.Conv2d(
223
+ 64, 1,
224
+ kernel_size=9, stride=1, padding=4
225
+ ),
226
+ nn.ReLU(),
227
+ )
228
+
229
+ # Initialize neural network weights.
230
+ self._initialize_weights()
231
+
232
+ def forward(self, x: Tensor) -> Tensor:
233
+ return self._forward_impl(x)
234
+
235
+ # Support torch.script function.
236
+ def _forward_impl(self, x: Tensor) -> Tensor:
237
+ out1 = self.conv_block1(x)
238
+ out = self.trunk(out1)
239
+ out2 = self.conv_block2(out)
240
+ out = out1 + out2
241
+ out_mu = self.conv_block3_mu(out)
242
+ out_alpha = self.conv_block3_alpha(out)
243
+ out_beta = self.conv_block3_beta(out)
244
+ return out_mu, out_alpha, out_beta
245
+
246
+ def _initialize_weights(self) -> None:
247
+ for module in self.modules():
248
+ if isinstance(module, nn.Conv2d):
249
+ nn.init.kaiming_normal_(module.weight)
250
+ if module.bias is not None:
251
+ nn.init.constant_(module.bias, 0)
252
+ elif isinstance(module, nn.BatchNorm2d):
253
+ nn.init.constant_(module.weight, 1)
254
+
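+ # Interpretation of the heads above (a hedged note, not original documentation): BayesCap
+ # predicts per-pixel parameters of a heteroscedastic generalized Gaussian around its own
+ # reconstruction: `mu` (reconstruction), `alpha` (inverse scale) and `beta` (shape).
+ # A minimal forward-pass sketch, assuming a 3-channel input:
+ #   net_c = BayesCap(in_channels=3, out_channels=3)
+ #   mu, alpha, beta = net_c(torch.randn(1, 3, 64, 64))
+ #   # mu: (1, 3, 64, 64); alpha, beta: (1, 1, 64, 64)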
255
+
256
+ class BayesCap_noID(nn.Module):
257
+ def __init__(self, in_channels=3, out_channels=3) -> None:
258
+ super(BayesCap_noID, self).__init__()
259
+ # First conv layer.
260
+ self.conv_block1 = nn.Sequential(
261
+ nn.Conv2d(
262
+ in_channels, 64,
263
+ kernel_size=9, stride=1, padding=4
264
+ ),
265
+ nn.PReLU(),
266
+ )
267
+
268
+ # Features trunk blocks.
269
+ trunk = []
270
+ for _ in range(16):
271
+ trunk.append(ResidualConvBlock(64))
272
+ self.trunk = nn.Sequential(*trunk)
273
+
274
+ # Second conv layer.
275
+ self.conv_block2 = nn.Sequential(
276
+ nn.Conv2d(
277
+ 64, 64,
278
+ kernel_size=3, stride=1, padding=1, bias=False
279
+ ),
280
+ nn.BatchNorm2d(64),
281
+ )
282
+
283
+ # Output layer.
284
+ # self.conv_block3_mu = nn.Conv2d(
285
+ # 64, out_channels=out_channels,
286
+ # kernel_size=9, stride=1, padding=4
287
+ # )
288
+ self.conv_block3_alpha = nn.Sequential(
289
+ nn.Conv2d(
290
+ 64, 64,
291
+ kernel_size=9, stride=1, padding=4
292
+ ),
293
+ nn.PReLU(),
294
+ nn.Conv2d(
295
+ 64, 64,
296
+ kernel_size=9, stride=1, padding=4
297
+ ),
298
+ nn.PReLU(),
299
+ nn.Conv2d(
300
+ 64, 1,
301
+ kernel_size=9, stride=1, padding=4
302
+ ),
303
+ nn.ReLU(),
304
+ )
305
+ self.conv_block3_beta = nn.Sequential(
306
+ nn.Conv2d(
307
+ 64, 64,
308
+ kernel_size=9, stride=1, padding=4
309
+ ),
310
+ nn.PReLU(),
311
+ nn.Conv2d(
312
+ 64, 64,
313
+ kernel_size=9, stride=1, padding=4
314
+ ),
315
+ nn.PReLU(),
316
+ nn.Conv2d(
317
+ 64, 1,
318
+ kernel_size=9, stride=1, padding=4
319
+ ),
320
+ nn.ReLU(),
321
+ )
322
+
323
+ # Initialize neural network weights.
324
+ self._initialize_weights()
325
+
326
+ def forward(self, x: Tensor) -> Tensor:
327
+ return self._forward_impl(x)
328
+
329
+ # Support torch.script function.
330
+ def _forward_impl(self, x: Tensor) -> Tensor:
331
+ out1 = self.conv_block1(x)
332
+ out = self.trunk(out1)
333
+ out2 = self.conv_block2(out)
334
+ out = out1 + out2
335
+ # out_mu = self.conv_block3_mu(out)
336
+ out_alpha = self.conv_block3_alpha(out)
337
+ out_beta = self.conv_block3_beta(out)
338
+ return out_alpha, out_beta
339
+
340
+ def _initialize_weights(self) -> None:
341
+ for module in self.modules():
342
+ if isinstance(module, nn.Conv2d):
343
+ nn.init.kaiming_normal_(module.weight)
344
+ if module.bias is not None:
345
+ nn.init.constant_(module.bias, 0)
346
+ elif isinstance(module, nn.BatchNorm2d):
347
+ nn.init.constant_(module.weight, 1)
src/networks_T1toT2.py ADDED
@@ -0,0 +1,477 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+
6
+ ### components
7
+ class ResConv(nn.Module):
8
+ """
9
+ Residual convolutional block, where
10
+ convolutional block consists: (convolution => [BN] => ReLU) * 3
11
+ residual connection adds the input to the output
12
+ """
13
+ def __init__(self, in_channels, out_channels, mid_channels=None):
14
+ super().__init__()
15
+ if not mid_channels:
16
+ mid_channels = out_channels
17
+ self.double_conv = nn.Sequential(
18
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
19
+ nn.BatchNorm2d(mid_channels),
20
+ nn.ReLU(inplace=True),
21
+ nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
22
+ nn.BatchNorm2d(mid_channels),
23
+ nn.ReLU(inplace=True),
24
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
25
+ nn.BatchNorm2d(out_channels),
26
+ nn.ReLU(inplace=True)
27
+ )
28
+ self.double_conv1 = nn.Sequential(
29
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
30
+ nn.BatchNorm2d(out_channels),
31
+ nn.ReLU(inplace=True),
32
+ )
33
+ def forward(self, x):
34
+ x_in = self.double_conv1(x)
35
+ x1 = self.double_conv(x)
36
+ return x1 + x_in  # residual connection: trunk output plus the projected input
37
+
38
+ class Down(nn.Module):
39
+ """Downscaling with maxpool then Resconv"""
40
+ def __init__(self, in_channels, out_channels):
41
+ super().__init__()
42
+ self.maxpool_conv = nn.Sequential(
43
+ nn.MaxPool2d(2),
44
+ ResConv(in_channels, out_channels)
45
+ )
46
+ def forward(self, x):
47
+ return self.maxpool_conv(x)
48
+
49
+ class Up(nn.Module):
50
+ """Upscaling then double conv"""
51
+ def __init__(self, in_channels, out_channels, bilinear=True):
52
+ super().__init__()
53
+ # if bilinear, use the normal convolutions to reduce the number of channels
54
+ if bilinear:
55
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
56
+ self.conv = ResConv(in_channels, out_channels, in_channels // 2)
57
+ else:
58
+ self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
59
+ self.conv = ResConv(in_channels, out_channels)
60
+ def forward(self, x1, x2):
61
+ x1 = self.up(x1)
62
+ # input is CHW
63
+ diffY = x2.size()[2] - x1.size()[2]
64
+ diffX = x2.size()[3] - x1.size()[3]
65
+ x1 = F.pad(
66
+ x1,
67
+ [
68
+ diffX // 2, diffX - diffX // 2,
69
+ diffY // 2, diffY - diffY // 2
70
+ ]
71
+ )
72
+ # if you have padding issues, see
73
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
74
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
75
+ x = torch.cat([x2, x1], dim=1)
76
+ return self.conv(x)
77
+
78
+ class OutConv(nn.Module):
79
+ def __init__(self, in_channels, out_channels):
80
+ super(OutConv, self).__init__()
81
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
82
+ def forward(self, x):
83
+ # return F.relu(self.conv(x))
84
+ return self.conv(x)
85
+
86
+ ##### The composite networks
87
+ class UNet(nn.Module):
88
+ def __init__(self, n_channels, out_channels, bilinear=True):
89
+ super(UNet, self).__init__()
90
+ self.n_channels = n_channels
91
+ self.out_channels = out_channels
92
+ self.bilinear = bilinear
93
+ ####
94
+ self.inc = ResConv(n_channels, 64)
95
+ self.down1 = Down(64, 128)
96
+ self.down2 = Down(128, 256)
97
+ self.down3 = Down(256, 512)
98
+ factor = 2 if bilinear else 1
99
+ self.down4 = Down(512, 1024 // factor)
100
+ self.up1 = Up(1024, 512 // factor, bilinear)
101
+ self.up2 = Up(512, 256 // factor, bilinear)
102
+ self.up3 = Up(256, 128 // factor, bilinear)
103
+ self.up4 = Up(128, 64, bilinear)
104
+ self.outc = OutConv(64, out_channels)
105
+ def forward(self, x):
106
+ x1 = self.inc(x)
107
+ x2 = self.down1(x1)
108
+ x3 = self.down2(x2)
109
+ x4 = self.down3(x3)
110
+ x5 = self.down4(x4)
111
+ x = self.up1(x5, x4)
112
+ x = self.up2(x, x3)
113
+ x = self.up3(x, x2)
114
+ x = self.up4(x, x1)
115
+ y = self.outc(x)
116
+ return y
117
+
118
+ class CasUNet(nn.Module):
119
+ def __init__(self, n_unet, io_channels, bilinear=True):
120
+ super(CasUNet, self).__init__()
121
+ self.n_unet = n_unet
122
+ self.io_channels = io_channels
123
+ self.bilinear = bilinear
124
+ ####
125
+ self.unet_list = nn.ModuleList()
126
+ for i in range(self.n_unet):
127
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
128
+ def forward(self, x, dop=None):
129
+ y = x
130
+ for i in range(self.n_unet):
131
+ if i==0:
132
+ if dop is not None:
133
+ y = F.dropout2d(self.unet_list[i](y), p=dop)
134
+ else:
135
+ y = self.unet_list[i](y)
136
+ else:
137
+ y = self.unet_list[i](y+x)
138
+ return y
139
+
140
+ class CasUNet_2head(nn.Module):
141
+ def __init__(self, n_unet, io_channels, bilinear=True):
142
+ super(CasUNet_2head, self).__init__()
143
+ self.n_unet = n_unet
144
+ self.io_channels = io_channels
145
+ self.bilinear = bilinear
146
+ ####
147
+ self.unet_list = nn.ModuleList()
148
+ for i in range(self.n_unet):
149
+ if i != self.n_unet-1:
150
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
151
+ else:
152
+ self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
153
+ def forward(self, x):
154
+ y = x
155
+ for i in range(self.n_unet):
156
+ if i==0:
157
+ y = self.unet_list[i](y)
158
+ else:
159
+ y = self.unet_list[i](y+x)
160
+ y_mean, y_sigma = y[0], y[1]
161
+ return y_mean, y_sigma
162
+
163
+ class CasUNet_3head(nn.Module):
164
+ def __init__(self, n_unet, io_channels, bilinear=True):
165
+ super(CasUNet_3head, self).__init__()
166
+ self.n_unet = n_unet
167
+ self.io_channels = io_channels
168
+ self.bilinear = bilinear
169
+ ####
170
+ self.unet_list = nn.ModuleList()
171
+ for i in range(self.n_unet):
172
+ if i != self.n_unet-1:
173
+ self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
174
+ else:
175
+ self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
176
+ def forward(self, x):
177
+ y = x
178
+ for i in range(self.n_unet):
179
+ if i==0:
180
+ y = self.unet_list[i](y)
181
+ else:
182
+ y = self.unet_list[i](y+x)
183
+ y_mean, y_alpha, y_beta = y[0], y[1], y[2]
184
+ return y_mean, y_alpha, y_beta
185
+
186
+ class UNet_2head(nn.Module):
187
+ def __init__(self, n_channels, out_channels, bilinear=True):
188
+ super(UNet_2head, self).__init__()
189
+ self.n_channels = n_channels
190
+ self.out_channels = out_channels
191
+ self.bilinear = bilinear
192
+ ####
193
+ self.inc = ResConv(n_channels, 64)
194
+ self.down1 = Down(64, 128)
195
+ self.down2 = Down(128, 256)
196
+ self.down3 = Down(256, 512)
197
+ factor = 2 if bilinear else 1
198
+ self.down4 = Down(512, 1024 // factor)
199
+ self.up1 = Up(1024, 512 // factor, bilinear)
200
+ self.up2 = Up(512, 256 // factor, bilinear)
201
+ self.up3 = Up(256, 128 // factor, bilinear)
202
+ self.up4 = Up(128, 64, bilinear)
203
+ #per pixel multiple channels may exist
204
+ self.out_mean = OutConv(64, out_channels)
205
+ #variance will always be a single number for a pixel
206
+ self.out_var = nn.Sequential(
207
+ OutConv(64, 128),
208
+ OutConv(128, 1),
209
+ )
210
+ def forward(self, x):
211
+ x1 = self.inc(x)
212
+ x2 = self.down1(x1)
213
+ x3 = self.down2(x2)
214
+ x4 = self.down3(x3)
215
+ x5 = self.down4(x4)
216
+ x = self.up1(x5, x4)
217
+ x = self.up2(x, x3)
218
+ x = self.up3(x, x2)
219
+ x = self.up4(x, x1)
220
+ y_mean, y_var = self.out_mean(x), self.out_var(x)
221
+ return y_mean, y_var
222
+
223
+ class UNet_3head(nn.Module):
224
+ def __init__(self, n_channels, out_channels, bilinear=True):
225
+ super(UNet_3head, self).__init__()
226
+ self.n_channels = n_channels
227
+ self.out_channels = out_channels
228
+ self.bilinear = bilinear
229
+ ####
230
+ self.inc = ResConv(n_channels, 64)
231
+ self.down1 = Down(64, 128)
232
+ self.down2 = Down(128, 256)
233
+ self.down3 = Down(256, 512)
234
+ factor = 2 if bilinear else 1
235
+ self.down4 = Down(512, 1024 // factor)
236
+ self.up1 = Up(1024, 512 // factor, bilinear)
237
+ self.up2 = Up(512, 256 // factor, bilinear)
238
+ self.up3 = Up(256, 128 // factor, bilinear)
239
+ self.up4 = Up(128, 64, bilinear)
240
+ #per pixel multiple channels may exist
241
+ self.out_mean = OutConv(64, out_channels)
242
+ #variance will always be a single number for a pixel
243
+ self.out_alpha = nn.Sequential(
244
+ OutConv(64, 128),
245
+ OutConv(128, 1),
246
+ nn.ReLU()
247
+ )
248
+ self.out_beta = nn.Sequential(
249
+ OutConv(64, 128),
250
+ OutConv(128, 1),
251
+ nn.ReLU()
252
+ )
253
+ def forward(self, x):
254
+ x1 = self.inc(x)
255
+ x2 = self.down1(x1)
256
+ x3 = self.down2(x2)
257
+ x4 = self.down3(x3)
258
+ x5 = self.down4(x4)
259
+ x = self.up1(x5, x4)
260
+ x = self.up2(x, x3)
261
+ x = self.up3(x, x2)
262
+ x = self.up4(x, x1)
263
+ y_mean, y_alpha, y_beta = self.out_mean(x), \
264
+ self.out_alpha(x), self.out_beta(x)
265
+ return y_mean, y_alpha, y_beta
266
+
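+ # Illustrative composition of the blocks above (shapes are an assumption): CasUNet_3head
+ # chains `n_unet` UNets and ends with UNet_3head, so one call yields a mean map plus
+ # per-pixel alpha/beta maps, e.g. for a single-channel slice:
+ #   net = CasUNet_3head(n_unet=1, io_channels=1)
+ #   y_mean, y_alpha, y_beta = net(torch.randn(1, 1, 256, 256))
+ #   # each output: (1, 1, 256, 256)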
267
+ class ResidualBlock(nn.Module):
268
+ def __init__(self, in_features):
269
+ super(ResidualBlock, self).__init__()
270
+ conv_block = [
271
+ nn.ReflectionPad2d(1),
272
+ nn.Conv2d(in_features, in_features, 3),
273
+ nn.InstanceNorm2d(in_features),
274
+ nn.ReLU(inplace=True),
275
+ nn.ReflectionPad2d(1),
276
+ nn.Conv2d(in_features, in_features, 3),
277
+ nn.InstanceNorm2d(in_features)
278
+ ]
279
+ self.conv_block = nn.Sequential(*conv_block)
280
+ def forward(self, x):
281
+ return x + self.conv_block(x)
282
+
283
+ class Generator(nn.Module):
284
+ def __init__(self, input_nc, output_nc, n_residual_blocks=9):
285
+ super(Generator, self).__init__()
286
+ # Initial convolution block
287
+ model = [
288
+ nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
289
+ nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
290
+ ]
291
+ # Downsampling
292
+ in_features = 64
293
+ out_features = in_features*2
294
+ for _ in range(2):
295
+ model += [
296
+ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
297
+ nn.InstanceNorm2d(out_features),
298
+ nn.ReLU(inplace=True)
299
+ ]
300
+ in_features = out_features
301
+ out_features = in_features*2
302
+ # Residual blocks
303
+ for _ in range(n_residual_blocks):
304
+ model += [ResidualBlock(in_features)]
305
+ # Upsampling
306
+ out_features = in_features//2
307
+ for _ in range(2):
308
+ model += [
309
+ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
310
+ nn.InstanceNorm2d(out_features),
311
+ nn.ReLU(inplace=True)
312
+ ]
313
+ in_features = out_features
314
+ out_features = in_features//2
315
+ # Output layer
316
+ model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
317
+ self.model = nn.Sequential(*model)
318
+ def forward(self, x):
319
+ return self.model(x)
320
+
321
+
322
+ class ResnetGenerator(nn.Module):
323
+ """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
324
+ We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
325
+ """
326
+
327
+ def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
328
+ """Construct a Resnet-based generator
329
+ Parameters:
330
+ input_nc (int) -- the number of channels in input images
331
+ output_nc (int) -- the number of channels in output images
332
+ ngf (int) -- the number of filters in the last conv layer
333
+ norm_layer -- normalization layer
334
+ use_dropout (bool) -- if use dropout layers
335
+ n_blocks (int) -- the number of ResNet blocks
336
+ padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
337
+ """
338
+ assert(n_blocks >= 0)
339
+ super(ResnetGenerator, self).__init__()
340
+ if type(norm_layer) == functools.partial:
341
+ use_bias = norm_layer.func == nn.InstanceNorm2d
342
+ else:
343
+ use_bias = norm_layer == nn.InstanceNorm2d
344
+
345
+ model = [nn.ReflectionPad2d(3),
346
+ nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
347
+ norm_layer(ngf),
348
+ nn.ReLU(True)]
349
+
350
+ n_downsampling = 2
351
+ for i in range(n_downsampling): # add downsampling layers
352
+ mult = 2 ** i
353
+ model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
354
+ norm_layer(ngf * mult * 2),
355
+ nn.ReLU(True)]
356
+
357
+ mult = 2 ** n_downsampling
358
+ for i in range(n_blocks): # add ResNet blocks
359
+
360
+ model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
361
+
362
+ for i in range(n_downsampling): # add upsampling layers
363
+ mult = 2 ** (n_downsampling - i)
364
+ model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
365
+ kernel_size=3, stride=2,
366
+ padding=1, output_padding=1,
367
+ bias=use_bias),
368
+ norm_layer(int(ngf * mult / 2)),
369
+ nn.ReLU(True)]
370
+ model += [nn.ReflectionPad2d(3)]
371
+ model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
372
+ model += [nn.Tanh()]
373
+
374
+ self.model = nn.Sequential(*model)
375
+
376
+ def forward(self, input):
377
+ """Standard forward"""
378
+ return self.model(input)
379
+
380
+
381
+ class ResnetBlock(nn.Module):
382
+ """Define a Resnet block"""
383
+
384
+ def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
385
+ """Initialize the Resnet block
386
+ A resnet block is a conv block with skip connections
387
+ We construct a conv block with build_conv_block function,
388
+ and implement skip connections in <forward> function.
389
+ Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
390
+ """
391
+ super(ResnetBlock, self).__init__()
392
+ self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
393
+
394
+ def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
395
+ """Construct a convolutional block.
396
+ Parameters:
397
+ dim (int) -- the number of channels in the conv layer.
398
+ padding_type (str) -- the name of padding layer: reflect | replicate | zero
399
+ norm_layer -- normalization layer
400
+ use_dropout (bool) -- if use dropout layers.
401
+ use_bias (bool) -- if the conv layer uses bias or not
402
+ Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
403
+ """
404
+ conv_block = []
405
+ p = 0
406
+ if padding_type == 'reflect':
407
+ conv_block += [nn.ReflectionPad2d(1)]
408
+ elif padding_type == 'replicate':
409
+ conv_block += [nn.ReplicationPad2d(1)]
410
+ elif padding_type == 'zero':
411
+ p = 1
412
+ else:
413
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
414
+
415
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
416
+ if use_dropout:
417
+ conv_block += [nn.Dropout(0.5)]
418
+
419
+ p = 0
420
+ if padding_type == 'reflect':
421
+ conv_block += [nn.ReflectionPad2d(1)]
422
+ elif padding_type == 'replicate':
423
+ conv_block += [nn.ReplicationPad2d(1)]
424
+ elif padding_type == 'zero':
425
+ p = 1
426
+ else:
427
+ raise NotImplementedError('padding [%s] is not implemented' % padding_type)
428
+ conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
429
+
430
+ return nn.Sequential(*conv_block)
431
+
432
+ def forward(self, x):
433
+ """Forward function (with skip connections)"""
434
+ out = x + self.conv_block(x) # add skip connections
435
+ return out
436
+
437
+ ### discriminator
438
+ class NLayerDiscriminator(nn.Module):
439
+ """Defines a PatchGAN discriminator"""
440
+ def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
441
+ """Construct a PatchGAN discriminator
442
+ Parameters:
443
+ input_nc (int) -- the number of channels in input images
444
+ ndf (int) -- the number of filters in the last conv layer
445
+ n_layers (int) -- the number of conv layers in the discriminator
446
+ norm_layer -- normalization layer
447
+ """
448
+ super(NLayerDiscriminator, self).__init__()
449
+ if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
450
+ use_bias = norm_layer.func == nn.InstanceNorm2d
451
+ else:
452
+ use_bias = norm_layer == nn.InstanceNorm2d
453
+ kw = 4
454
+ padw = 1
455
+ sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
456
+ nf_mult = 1
457
+ nf_mult_prev = 1
458
+ for n in range(1, n_layers): # gradually increase the number of filters
459
+ nf_mult_prev = nf_mult
460
+ nf_mult = min(2 ** n, 8)
461
+ sequence += [
462
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
463
+ norm_layer(ndf * nf_mult),
464
+ nn.LeakyReLU(0.2, True)
465
+ ]
466
+ nf_mult_prev = nf_mult
467
+ nf_mult = min(2 ** n_layers, 8)
468
+ sequence += [
469
+ nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
470
+ norm_layer(ndf * nf_mult),
471
+ nn.LeakyReLU(0.2, True)
472
+ ]
473
+ sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
474
+ self.model = nn.Sequential(*sequence)
475
+ def forward(self, input):
476
+ """Standard forward."""
477
+ return self.model(input)
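+ # Illustrative note (input size is an assumption): with the default n_layers=3 the PatchGAN
+ # above maps a 256x256 image to a grid of per-patch real/fake scores, e.g.
+ #   NLayerDiscriminator(input_nc=1)(torch.randn(1, 1, 256, 256))  # -> (1, 1, 30, 30)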
src/utils.py ADDED
@@ -0,0 +1,1273 @@
1
+ import random
2
+ from typing import Any, Optional, Tuple
3
+ import numpy as np
4
+ import os
5
+ import cv2
6
+ from glob import glob
7
+ from PIL import Image, ImageDraw
8
+ from tqdm import tqdm
9
+ import kornia
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ import albumentations as albu
13
+ import functools
14
+ import math
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+ import torchvision as tv
20
+ import torchvision.models as models
21
+ from torchvision import transforms
22
+ from torchvision.transforms import functional as F
23
+ from losses import TempCombLoss
24
+
25
+ ########### DeblurGAN function
26
+ def get_norm_layer(norm_type='instance'):
27
+ if norm_type == 'batch':
28
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
29
+ elif norm_type == 'instance':
30
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
31
+ else:
32
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
33
+ return norm_layer
34
+
35
+ def _array_to_batch(x):
36
+ x = np.transpose(x, (2, 0, 1))
37
+ x = np.expand_dims(x, 0)
38
+ return torch.from_numpy(x)
39
+
40
+ def get_normalize():
41
+ normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
42
+ normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
43
+
44
+ def process(a, b):
45
+ r = normalize(image=a, target=b)
46
+ return r['image'], r['target']
47
+
48
+ return process
49
+
50
+ def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
51
+ x, _ = get_normalize()(x, x)
52
+ if mask is None:
53
+ mask = np.ones_like(x, dtype=np.float32)
54
+ else:
55
+ mask = np.round(mask.astype('float32') / 255)
56
+
57
+ h, w, _ = x.shape
58
+ block_size = 32
59
+ min_height = (h // block_size + 1) * block_size
60
+ min_width = (w // block_size + 1) * block_size
61
+
62
+ pad_params = {'mode': 'constant',
63
+ 'constant_values': 0,
64
+ 'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
65
+ }
66
+ x = np.pad(x, **pad_params)
67
+ mask = np.pad(mask, **pad_params)
68
+
69
+ return map(_array_to_batch, (x, mask)), h, w
70
+
71
+ def postprocess(x: torch.Tensor) -> np.ndarray:
72
+ x, = x
73
+ x = x.detach().cpu().float().numpy()
74
+ x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
75
+ return x.astype('uint8')
76
+
77
+ def sorted_glob(pattern):
78
+ return sorted(glob(pattern))
79
+ ###########
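+ # Illustrative round trip through the DeblurGAN helpers above ('blurry.png' and `model` are
+ # placeholders): preprocess normalizes to [-1, 1] and zero-pads up to a multiple of 32,
+ # postprocess maps the network output back to a uint8 image.
+ #   img = cv2.cvtColor(cv2.imread('blurry.png'), cv2.COLOR_BGR2RGB)
+ #   (x_t, mask_t), h, w = preprocess(img, None)
+ #   restored = postprocess(model(x_t.to('cuda')))[:h, :w]  # crop the padding back off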
80
+
81
+ def normalize(image: np.ndarray) -> np.ndarray:
82
+ """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
83
+ Args:
84
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
85
+ Returns:
86
+ Normalized image data. Data range [0, 1].
87
+ """
88
+ return image.astype(np.float64) / 255.0
89
+
90
+
91
+ def unnormalize(image: np.ndarray) -> np.ndarray:
92
+ """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
93
+ Args:
94
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
95
+ Returns:
96
+ Denormalized image data. Data range [0, 255].
97
+ """
98
+ return image.astype(np.float64) * 255.0
99
+
100
+
101
+ def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
102
+ """Convert ``PIL.Image`` to Tensor.
103
+ Args:
104
+ image (np.ndarray): The image data read by ``PIL.Image``
105
+ range_norm (bool): Scale [0, 1] data to between [-1, 1]
106
+ half (bool): Whether to convert the tensor to ``torch.half`` (float16).
107
+ Returns:
108
+ Normalized image data
109
+ Examples:
110
+ >>> image = Image.open("image.bmp")
111
+ >>> tensor_image = image2tensor(image, range_norm=False, half=False)
112
+ """
113
+ tensor = F.to_tensor(image)
114
+
115
+ if range_norm:
116
+ tensor = tensor.mul_(2.0).sub_(1.0)
117
+ if half:
118
+ tensor = tensor.half()
119
+
120
+ return tensor
121
+
122
+
123
+ def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
124
+ """Convert ``torch.Tensor`` image data to a ``uint8`` array that the PIL library can consume.
125
+ Args:
126
+ tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
127
+ range_norm (bool): Scale [-1, 1] data to between [0, 1]
128
+ half (bool): Whether to convert the tensor to ``torch.half`` (float16).
129
+ Returns:
130
+ Image data as a ``uint8`` array compatible with the PIL library.
131
+ Examples:
132
+ >>> tensor = torch.randn([1, 3, 128, 128])
133
+ >>> image = tensor2image(tensor, range_norm=False, half=False)
134
+ """
135
+ if range_norm:
136
+ tensor = tensor.add_(1.0).div_(2.0)
137
+ if half:
138
+ tensor = tensor.half()
139
+
140
+ image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
141
+
142
+ return image
143
+
144
+
145
+ def convert_rgb_to_y(image: Any) -> Any:
146
+ """Convert RGB image or tensor image data to YCbCr(Y) format.
147
+ Args:
148
+ image: RGB image data read by ``PIL.Image''.
149
+ Returns:
150
+ Y image array data.
151
+ """
152
+ if type(image) == np.ndarray:
153
+ return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
154
+ elif type(image) == torch.Tensor:
155
+ if len(image.shape) == 4:
156
+ image = image.squeeze_(0)
157
+ return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
158
+ else:
159
+ raise Exception("Unknown Type", type(image))
160
+
161
+
162
+ def convert_rgb_to_ycbcr(image: Any) -> Any:
163
+ """Convert RGB image or tensor image data to YCbCr format.
164
+ Args:
165
+ image: RGB image data read by ``PIL.Image''.
166
+ Returns:
167
+ YCbCr image array data.
168
+ """
169
+ if type(image) == np.ndarray:
170
+ y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
171
+ cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
172
+ cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
173
+ return np.array([y, cb, cr]).transpose([1, 2, 0])
174
+ elif type(image) == torch.Tensor:
175
+ if len(image.shape) == 4:
176
+ image = image.squeeze(0)
177
+ y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
178
+ cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
179
+ cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
180
+ return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
181
+ else:
182
+ raise Exception("Unknown Type", type(image))
183
+
184
+
185
+ def convert_ycbcr_to_rgb(image: Any) -> Any:
186
+ """Convert YCbCr format image to RGB format.
187
+ Args:
188
+ image: YCbCr image data read by ``PIL.Image''.
189
+ Returns:
190
+ RGB image array data.
191
+ """
192
+ if type(image) == np.ndarray:
193
+ r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
194
+ g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
195
+ b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
196
+ return np.array([r, g, b]).transpose([1, 2, 0])
197
+ elif type(image) == torch.Tensor:
198
+ if len(image.shape) == 4:
199
+ image = image.squeeze(0)
200
+ r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
201
+ g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
202
+ b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
203
+ return torch.cat([r, g, b], 0).permute(1, 2, 0)
204
+ else:
205
+ raise Exception("Unknown Type", type(image))
206
+
207
+
208
+ def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> Tuple[Any, Any]:
209
+ """Cut ``PIL.Image`` in the center area of the image.
210
+ Args:
211
+ lr: Low-resolution image data read by ``PIL.Image``.
212
+ hr: High-resolution image data read by ``PIL.Image``.
213
+ image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
214
+ upscale_factor (int): magnification factor.
215
+ Returns:
216
+ Center-cropped low-resolution and high-resolution images.
217
+ """
218
+ w, h = hr.size
219
+
220
+ left = (w - image_size) // 2
221
+ top = (h - image_size) // 2
222
+ right = left + image_size
223
+ bottom = top + image_size
224
+
225
+ lr = lr.crop((left // upscale_factor,
226
+ top // upscale_factor,
227
+ right // upscale_factor,
228
+ bottom // upscale_factor))
229
+ hr = hr.crop((left, top, right, bottom))
230
+
231
+ return lr, hr
232
+
233
+
234
+ def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> Tuple[Any, Any]:
235
+ """Randomly crop a region of size ``image_size`` from the ``PIL.Image`` pair.
236
+ Args:
237
+ lr: Low-resolution image data read by ``PIL.Image``.
238
+ hr: High-resolution image data read by ``PIL.Image``.
239
+ image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
240
+ upscale_factor (int): magnification factor.
241
+ Returns:
242
+ Randomly cropped low-resolution images and high-resolution images.
243
+ """
244
+ w, h = hr.size
245
+ left = torch.randint(0, w - image_size + 1, size=(1,)).item()
246
+ top = torch.randint(0, h - image_size + 1, size=(1,)).item()
247
+ right = left + image_size
248
+ bottom = top + image_size
249
+
250
+ lr = lr.crop((left // upscale_factor,
251
+ top // upscale_factor,
252
+ right // upscale_factor,
253
+ bottom // upscale_factor))
254
+ hr = hr.crop((left, top, right, bottom))
255
+
256
+ return lr, hr
257
+
258
+
259
+ def random_rotate(lr: Any, hr: Any, angle: int) -> Tuple[Any, Any]:
260
+ """Randomly rotate the ``PIL.Image`` pair by ``+angle`` or ``-angle``.
261
+ Args:
262
+ lr: Low-resolution image data read by ``PIL.Image``.
263
+ hr: High-resolution image data read by ``PIL.Image``.
264
+ angle (int): rotation magnitude; the pair is rotated by ``+angle`` or ``-angle``, chosen at random.
265
+ Returns:
266
+ Randomly rotated low-resolution images and high-resolution images.
267
+ """
268
+ angle = random.choice((+angle, -angle))
269
+ lr = F.rotate(lr, angle)
270
+ hr = F.rotate(hr, angle)
271
+
272
+ return lr, hr
273
+
274
+
275
+ def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> Tuple[Any, Any]:
276
+ """Randomly flip the ``PIL.Image`` pair horizontally.
277
+ Args:
278
+ lr: Low-resolution image data read by ``PIL.Image``.
279
+ hr: High-resolution image data read by ``PIL.Image``.
280
+ p (optional, float): flip probability. (Default: 0.5)
281
+ Returns:
282
+ Low-resolution image and high-resolution image after random horizontal flip.
283
+ """
284
+ if torch.rand(1).item() > p:
285
+ lr = F.hflip(lr)
286
+ hr = F.hflip(hr)
287
+
288
+ return lr, hr
289
+
290
+
291
+ def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> Tuple[Any, Any]:
292
+ """Randomly flip the ``PIL.Image`` pair vertically.
293
+ Args:
294
+ lr: Low-resolution image data read by ``PIL.Image``.
295
+ hr: High-resolution image data read by ``PIL.Image``.
296
+ p (optional, float): flip probability. (Default: 0.5)
297
+ Returns:
298
+ Randomly vertically flipped low-resolution and high-resolution images.
299
+ """
300
+ if torch.rand(1).item() > p:
301
+ lr = F.vflip(lr)
302
+ hr = F.vflip(hr)
303
+
304
+ return lr, hr
305
+
306
+
307
+ def random_adjust_brightness(lr: Any, hr: Any) -> Tuple[Any, Any]:
308
+ """Randomly adjust the brightness of the ``PIL.Image`` pair.
309
+ Args:
310
+ lr: Low-resolution image data read by ``PIL.Image``.
311
+ hr: High-resolution image data read by ``PIL.Image``.
312
+ Returns:
313
+ Low-resolution image and high-resolution image with randomly adjusted brightness.
314
+ """
315
+ # Randomly adjust the brightness gain range.
316
+ factor = random.uniform(0.5, 2)
317
+ lr = F.adjust_brightness(lr, factor)
318
+ hr = F.adjust_brightness(hr, factor)
319
+
320
+ return lr, hr
321
+
322
+
323
+ def random_adjust_contrast(lr: Any, hr: Any) -> Tuple[Any, Any]:
324
+ """Randomly adjust the contrast of the ``PIL.Image`` pair.
325
+ Args:
326
+ lr: Low-resolution image data read by ``PIL.Image``.
327
+ hr: High-resolution image data read by ``PIL.Image``.
328
+ Returns:
329
+ Low-resolution image and high-resolution image with randomly adjusted contrast.
330
+ """
331
+ # Randomly adjust the contrast gain range.
332
+ factor = random.uniform(0.5, 2)
333
+ lr = F.adjust_contrast(lr, factor)
334
+ hr = F.adjust_contrast(hr, factor)
335
+
336
+ return lr, hr
337
+
338
+ #### metrics to compute -- assumes single images, i.e., tensor of 3 dims
339
+ def img_mae(x1, x2):
340
+ m = torch.abs(x1-x2).mean()
341
+ return m
342
+
343
+ def img_mse(x1, x2):
344
+ m = torch.pow(torch.abs(x1-x2),2).mean()
345
+ return m
346
+
347
+ def img_psnr(x1, x2):
348
+ m = kornia.metrics.psnr(x1, x2, 1)
349
+ return m
350
+
351
+ def img_ssim(x1, x2):
352
+ m = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
353
+ m = m.mean()
354
+ return m
355
+
356
+ def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
357
+ '''
358
+ xLR/SR/HR: 3xHxW
359
+ xSRvar: 1xHxW
360
+ '''
361
+ plt.figure(figsize=(30,10))
362
+
363
+ plt.subplot(1,5,1)
364
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
365
+ plt.axis('off')
366
+
367
+ plt.subplot(1,5,2)
368
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
369
+ plt.axis('off')
370
+
371
+ plt.subplot(1,5,3)
372
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
373
+ plt.axis('off')
374
+
375
+ plt.subplot(1,5,4)
376
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
377
+ print('error', error_map.min(), error_map.max())
378
+ plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
379
+ plt.clim(elim[0], elim[1])
380
+ plt.axis('off')
381
+
382
+ plt.subplot(1,5,5)
383
+ print('uncer', xSRvar.min(), xSRvar.max())
384
+ plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
385
+ plt.clim(ulim[0], ulim[1])
386
+ plt.axis('off')
387
+
388
+ plt.subplots_adjust(wspace=0, hspace=0)
389
+ plt.show()
390
+
391
+ def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
392
+ '''
393
+ xLR/SR/HR: 3xHxW
394
+ '''
395
+ plt.figure(figsize=(30,10))
396
+
397
+ if task != 'm':
398
+ plt.subplot(1,4,1)
399
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
400
+ plt.axis('off')
401
+
402
+ plt.subplot(1,4,2)
403
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
404
+ plt.axis('off')
405
+
406
+ plt.subplot(1,4,3)
407
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
408
+ plt.axis('off')
409
+ else:
410
+ plt.subplot(1,4,1)
411
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
412
+ plt.clim(0,0.9)
413
+ plt.axis('off')
414
+
415
+ plt.subplot(1,4,2)
416
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
417
+ plt.clim(0,0.9)
418
+ plt.axis('off')
419
+
420
+ plt.subplot(1,4,3)
421
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
422
+ plt.clim(0,0.9)
423
+ plt.axis('off')
424
+
425
+ plt.subplot(1,4,4)
426
+ if task == 'inpainting':
427
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
428
+ else:
429
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
430
+ print('error', error_map.min(), error_map.max())
431
+ plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
432
+ plt.clim(elim[0], elim[1])
433
+ plt.axis('off')
434
+
435
+ plt.subplots_adjust(wspace=0, hspace=0)
436
+ plt.show()
437
+
438
+ def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
439
+ '''
440
+ xSRvar: 1xHxW
441
+ '''
442
+ plt.figure(figsize=(30,10))
443
+
444
+ plt.subplot(1,4,1)
445
+ print('uncer', xSRvar1.min(), xSRvar1.max())
446
+ plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
447
+ plt.clim(ulim[0], ulim[1])
448
+ plt.axis('off')
449
+
450
+ plt.subplot(1,4,2)
451
+ print('uncer', xSRvar2.min(), xSRvar2.max())
452
+ plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
453
+ plt.clim(ulim[0], ulim[1])
454
+ plt.axis('off')
455
+
456
+ plt.subplot(1,4,3)
457
+ print('uncer', xSRvar3.min(), xSRvar3.max())
458
+ plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
459
+ plt.clim(ulim[0], ulim[1])
460
+ plt.axis('off')
461
+
462
+ plt.subplot(1,4,4)
463
+ print('uncer', xSRvar4.min(), xSRvar4.max())
464
+ plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
465
+ plt.clim(ulim[0], ulim[1])
466
+ plt.axis('off')
467
+
468
+ plt.subplots_adjust(wspace=0, hspace=0)
469
+ plt.show()
470
+
471
+ def get_UCE(list_err, list_yout_var, num_bins=100):
472
+ err_min = np.min(list_err)
473
+ err_max = np.max(list_err)
474
+ err_len = (err_max-err_min)/num_bins
475
+ num_points = len(list_err)
476
+
477
+ bin_stats = {}
478
+ for i in range(num_bins):
479
+ bin_stats[i] = {
480
+ 'start_idx': err_min + i*err_len,
481
+ 'end_idx': err_min + (i+1)*err_len,
482
+ 'num_points': 0,
483
+ 'mean_err': 0,
484
+ 'mean_var': 0,
485
+ }
486
+
487
+ for e,v in zip(list_err, list_yout_var):
488
+ for i in range(num_bins):
489
+ if e>=bin_stats[i]['start_idx'] and e<bin_stats[i]['end_idx']:
490
+ bin_stats[i]['num_points'] += 1
491
+ bin_stats[i]['mean_err'] += e
492
+ bin_stats[i]['mean_var'] += v
493
+
494
+ uce = 0
495
+ eps = 1e-8
496
+ for i in range(num_bins):
497
+ bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
498
+ bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
499
+ bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points']/num_points) \
500
+ *(np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
501
+ uce += bin_stats[i]['uce_bin']
502
+
503
+ list_x, list_y = [], []
504
+ for i in range(num_bins):
505
+ if bin_stats[i]['num_points']>0:
506
+ list_x.append(bin_stats[i]['mean_err'])
507
+ list_y.append(bin_stats[i]['mean_var'])
508
+
509
+ # sns.set_style('darkgrid')
510
+ # sns.scatterplot(x=list_x, y=list_y)
511
+ # sns.regplot(x=list_x, y=list_y, order=1)
512
+ # plt.xlabel('MSE', fontsize=34)
513
+ # plt.ylabel('Uncertainty', fontsize=34)
514
+ # plt.plot(list_x, list_x, color='r')
515
+ # plt.xlim(np.min(list_x), np.max(list_x))
516
+ # plt.ylim(np.min(list_err), np.max(list_x))
517
+ # plt.show()
518
+
519
+ return bin_stats, uce
520
+
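+ # Illustrative use of get_UCE (values below are made up): pixels are binned by squared error,
+ # and UCE sums |mean_err - mean_var| per bin, weighted by bin occupancy, so variance that
+ # tracks error yields a small UCE.
+ #   errs = [0.01, 0.02, 0.05, 0.20]
+ #   vars_ = [0.012, 0.018, 0.049, 0.050]
+ #   _, uce = get_UCE(errs, vars_, num_bins=4)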
521
+ ##################### training BayesCap
522
+ def train_BayesCap(
523
+ NetC,
524
+ NetG,
525
+ train_loader,
526
+ eval_loader,
527
+ Cri = TempCombLoss(),
528
+ device='cuda',
529
+ dtype=torch.cuda.FloatTensor,  # the tensor type itself, not an instance (matches the eval functions below)
530
+ init_lr=1e-4,
531
+ num_epochs=100,
532
+ eval_every=1,
533
+ ckpt_path='../ckpt/BayesCap',
534
+ T1=1e0,
535
+ T2=5e-2,
536
+ task=None,
537
+ ):
538
+ NetC.to(device)
539
+ NetC.train()
540
+ NetG.to(device)
541
+ NetG.eval()
542
+ optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
543
+ optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
544
+
545
+ score = -1e8
546
+ all_loss = []
547
+ for eph in range(num_epochs):
548
+ eph_loss = 0
549
+ with tqdm(train_loader, unit='batch') as tepoch:
550
+ for (idx, batch) in enumerate(tepoch):
551
+ if idx>2000:
552
+ break
553
+ tepoch.set_description('Epoch {}'.format(eph))
554
+ ##
555
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
556
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
557
+ if task == 'inpainting':
558
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
559
+ xMask = xMask.to(device).type(dtype)
560
+ # pass them through the network
561
+ with torch.no_grad():
562
+ if task == 'inpainting':
563
+ _, xSR1 = NetG(xLR, xMask)
564
+ elif task == 'depth':
565
+ xSR1 = NetG(xLR)[("disp", 0)]
566
+ else:
567
+ xSR1 = NetG(xLR)
568
+ # with torch.autograd.set_detect_anomaly(True):
569
+ xSR = xSR1.clone()
570
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
571
+ # print(xSRC_alpha)
572
+ optimizer.zero_grad()
573
+ if task == 'depth':
574
+ loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
575
+ else:
576
+ loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
577
+ # print(loss)
578
+ loss.backward()
579
+ optimizer.step()
580
+ ##
581
+ eph_loss += loss.item()
582
+ tepoch.set_postfix(loss=loss.item())
583
+ eph_loss /= len(train_loader)
584
+ all_loss.append(eph_loss)
585
+ print('Avg. loss: {}'.format(eph_loss))
586
+ # evaluate and save the models
587
+ torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
588
+ if eph%eval_every == 0:
589
+ curr_score = eval_BayesCap(
590
+ NetC,
591
+ NetG,
592
+ eval_loader,
593
+ device=device,
594
+ dtype=dtype,
595
+ task=task,
596
+ )
597
+ print('current score: {} | Last best score: {}'.format(curr_score, score))
598
+ if curr_score >= score:
599
+ score = curr_score
600
+ torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
601
+ optim_scheduler.step()
602
+
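+ # Minimal usage sketch (checkpoint path and data loaders are placeholders): only NetC
+ # (BayesCap) is optimized; NetG stays frozen in eval mode and only supplies the
+ # reconstructions whose errors BayesCap learns to describe.
+ #   net_g = Generator(); net_g.load_state_dict(torch.load('srgan.pth'))
+ #   net_c = BayesCap(in_channels=3, out_channels=3)
+ #   train_BayesCap(net_c, net_g, train_loader, eval_loader, num_epochs=20, ckpt_path='../ckpt/BayesCap')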
603
+ #### get different uncertainty maps
604
+ def get_uncer_BayesCap(
605
+ NetC,
606
+ NetG,
607
+ xin,
608
+ task=None,
609
+ xMask=None,
610
+ ):
611
+ with torch.no_grad():
612
+ if task == 'inpainting':
613
+ _, xSR = NetG(xin, xMask)
614
+ else:
615
+ xSR = NetG(xin)
616
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
617
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
618
+ b_map = xSRC_beta.to('cpu').data
619
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
620
+
621
+ return xSRvar
622
+
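+ # Note on the expression above: it is the closed-form variance of a generalized Gaussian
+ # with scale 1/alpha and shape beta,
+ #   Var = (1/alpha)^2 * Gamma(3/beta) / Gamma(1/beta),
+ # evaluated via lgamma for numerical stability; the small constants (1e-5, 1e-2) guard
+ # against division by zero for very small alpha or beta.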
623
+ def get_uncer_TTDAp(
624
+ NetG,
625
+ xin,
626
+ p_mag=0.05,
627
+ num_runs=50,
628
+ task=None,
629
+ xMask=None,
630
+ ):
631
+ list_xSR = []
632
+ with torch.no_grad():
633
+ for z in range(num_runs):
634
+ if task == 'inpainting':
635
+ _, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask)
636
+ else:
637
+ xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin))
638
+ list_xSR.append(xSRz)
639
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
640
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
641
+ return xSRvar
642
+
643
+ def get_uncer_DO(
644
+ NetG,
645
+ xin,
646
+ dop=0.2,
647
+ num_runs=50,
648
+ task=None,
649
+ xMask=None,
650
+ ):
651
+ list_xSR = []
652
+ with torch.no_grad():
653
+ for z in range(num_runs):
654
+ if task == 'inpainting':
655
+ _, xSRz = NetG(xin, xMask, dop=dop)
656
+ else:
657
+ xSRz = NetG(xin, dop=dop)
658
+ list_xSR.append(xSRz)
659
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
660
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
661
+ return xSRvar
662
+
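+ # Illustrative comparison of the three estimators above (`net_g`, `net_c`, `x_lr` are
+ # placeholders): BayesCap needs one forward pass, while the perturbation and dropout
+ # baselines average `num_runs` stochastic passes.
+ #   u_ttda = get_uncer_TTDAp(net_g, x_lr, p_mag=0.05, num_runs=50)
+ #   u_do   = get_uncer_DO(net_g, x_lr, dop=0.2, num_runs=50)
+ #   u_bc   = get_uncer_BayesCap(net_c, net_g, x_lr)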
663
+ ################### Different eval functions
664
+
665
+ def eval_BayesCap(
666
+ NetC,
667
+ NetG,
668
+ eval_loader,
669
+ device='cuda',
670
+ dtype=torch.cuda.FloatTensor,
671
+ task=None,
672
+ xMask=None,
673
+ ):
674
+ NetC.to(device)
675
+ NetC.eval()
676
+ NetG.to(device)
677
+ NetG.eval()
678
+
679
+ mean_ssim = 0
680
+ mean_psnr = 0
681
+ mean_mse = 0
682
+ mean_mae = 0
683
+ num_imgs = 0
684
+ list_error = []
685
+ list_var = []
686
+ with tqdm(eval_loader, unit='batch') as tepoch:
687
+ for (idx, batch) in enumerate(tepoch):
688
+ tepoch.set_description('Validating ...')
689
+ ##
690
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
691
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
692
+ if task == 'inpainting':
693
+ if xMask is None:
694
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
695
+ xMask = xMask.to(device).type(dtype)
696
+ else:
697
+ xMask = xMask.to(device).type(dtype)
698
+ # pass them through the network
699
+ with torch.no_grad():
700
+ if task == 'inpainting':
701
+ _, xSR = NetG(xLR, xMask)
702
+ elif task == 'depth':
703
+ xSR = NetG(xLR)[("disp", 0)]
704
+ else:
705
+ xSR = NetG(xLR)
706
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
707
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
708
+ b_map = xSRC_beta.to('cpu').data
709
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
710
+ n_batch = xSRC_mu.shape[0]
711
+ if task == 'depth':
712
+ xHR = xSR
713
+ for j in range(n_batch):
714
+ num_imgs += 1
715
+ mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
716
+ mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
717
+ mean_mse += img_mse(xSRC_mu[j], xHR[j])
718
+ mean_mae += img_mae(xSRC_mu[j], xHR[j])
719
+
720
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
721
+
722
+ error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
723
+ var_map = xSRvar[j].to('cpu').data.reshape(-1)
724
+ list_error.extend(list(error_map.numpy()))
725
+ list_var.extend(list(var_map.numpy()))
726
+ ##
727
+ mean_ssim /= num_imgs
728
+ mean_psnr /= num_imgs
729
+ mean_mse /= num_imgs
730
+ mean_mae /= num_imgs
731
+ print(
732
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
733
+ (
734
+ mean_ssim, mean_psnr, mean_mse, mean_mae
735
+ )
736
+ )
737
+ # print(len(list_error), len(list_var))
738
+ # print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
739
+ # print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
740
+ return mean_ssim
741
+
742
+ def eval_TTDA_p(
743
+ NetG,
744
+ eval_loader,
745
+ device='cuda',
746
+ dtype=torch.cuda.FloatTensor,
747
+ p_mag=0.05,
748
+ num_runs=50,
749
+ task = None,
750
+ xMask = None,
751
+ ):
752
+ NetG.to(device)
753
+ NetG.eval()
754
+
755
+ mean_ssim = 0
756
+ mean_psnr = 0
757
+ mean_mse = 0
758
+ mean_mae = 0
759
+ num_imgs = 0
760
+ with tqdm(eval_loader, unit='batch') as tepoch:
761
+ for (idx, batch) in enumerate(tepoch):
762
+ tepoch.set_description('Validating ...')
763
+ ##
764
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
765
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
766
+ # pass them through the network
767
+ list_xSR = []
768
+ with torch.no_grad():
769
+ if task=='inpainting':
770
+ _, xSR = NetG(xLR, xMask)
771
+ else:
772
+ xSR = NetG(xLR)
773
+ for z in range(num_runs):
774
+ xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
775
+ list_xSR.append(xSRz)
776
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
777
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
778
+ n_batch = xSR.shape[0]
779
+ for j in range(n_batch):
780
+ num_imgs += 1
781
+ mean_ssim += img_ssim(xSR[j], xHR[j])
782
+ mean_psnr += img_psnr(xSR[j], xHR[j])
783
+ mean_mse += img_mse(xSR[j], xHR[j])
784
+ mean_mae += img_mae(xSR[j], xHR[j])
785
+
786
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
787
+
788
+ mean_ssim /= num_imgs
789
+ mean_psnr /= num_imgs
790
+ mean_mse /= num_imgs
791
+ mean_mae /= num_imgs
792
+ print(
793
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
794
+ (
795
+ mean_ssim, mean_psnr, mean_mse, mean_mae
796
+ )
797
+ )
798
+
799
+ return mean_ssim
800
+
801
+ def eval_DO(
802
+ NetG,
803
+ eval_loader,
804
+ device='cuda',
805
+ dtype=torch.cuda.FloatTensor,
806
+ dop=0.2,
807
+ num_runs=50,
808
+ task=None,
809
+ xMask=None,
810
+ ):
811
+ NetG.to(device)
812
+ NetG.eval()
813
+
814
+ mean_ssim = 0
815
+ mean_psnr = 0
816
+ mean_mse = 0
817
+ mean_mae = 0
818
+ num_imgs = 0
819
+ with tqdm(eval_loader, unit='batch') as tepoch:
820
+ for (idx, batch) in enumerate(tepoch):
821
+ tepoch.set_description('Validating ...')
822
+ ##
823
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
824
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
825
+ # pass them through the network
826
+ list_xSR = []
827
+ with torch.no_grad():
828
+ if task == 'inpainting':
829
+ _, xSR = NetG(xLR, xMask)
830
+ else:
831
+ xSR = NetG(xLR)
832
+ for z in range(num_runs):
833
+ xSRz = NetG(xLR, dop=dop)
834
+ list_xSR.append(xSRz)
835
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
836
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
837
+ n_batch = xSR.shape[0]
838
+ for j in range(n_batch):
839
+ num_imgs += 1
840
+ mean_ssim += img_ssim(xSR[j], xHR[j])
841
+ mean_psnr += img_psnr(xSR[j], xHR[j])
842
+ mean_mse += img_mse(xSR[j], xHR[j])
843
+ mean_mae += img_mae(xSR[j], xHR[j])
844
+
845
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
846
+ ##
847
+ mean_ssim /= num_imgs
848
+ mean_psnr /= num_imgs
849
+ mean_mse /= num_imgs
850
+ mean_mae /= num_imgs
851
+ print(
852
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
853
+ (
854
+ mean_ssim, mean_psnr, mean_mse, mean_mae
855
+ )
856
+ )
857
+
858
+ return mean_ssim
859
+
860
+
861
+ ############### compare all function
862
+ def compare_all(
863
+ NetC,
864
+ NetG,
865
+ eval_loader,
866
+ p_mag = 0.05,
867
+ dop = 0.2,
868
+ num_runs = 100,
869
+ device='cuda',
870
+ dtype=torch.cuda.FloatTensor,
871
+ task=None,
872
+ ):
873
+ NetC.to(device)
874
+ NetC.eval()
875
+ NetG.to(device)
876
+ NetG.eval()
877
+
878
+ with tqdm(eval_loader, unit='batch') as tepoch:
879
+ for (idx, batch) in enumerate(tepoch):
880
+ tepoch.set_description('Comparing ...')
881
+ ##
882
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
883
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
884
+ if task == 'inpainting':
885
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
886
+ xMask = xMask.to(device).type(dtype)
887
+ # pass them through the network
888
+ with torch.no_grad():
889
+ if task == 'inpainting':
890
+ _, xSR = NetG(xLR, xMask)
891
+ else:
892
+ xSR = NetG(xLR)
893
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
894
+
895
+ if task == 'inpainting':
896
+ xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
897
+ xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
898
+ xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
899
+ else:
900
+ xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
901
+ xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
902
+ xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)
903
+
904
+ print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
905
+
906
+ n_batch = xSR.shape[0]
907
+ for j in range(n_batch):
908
+ if task=='s':
909
+ show_SR_w_err(xLR[j], xHR[j], xSR[j])
910
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
911
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
912
+ if task=='d':
913
+ show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
914
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
915
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
916
+ if task=='inpainting':
917
+ show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
918
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
919
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
920
+ if task=='m':
921
+ show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
922
+ show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
923
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
924
+
925
+
926
+ ################# Degrading Identity
927
+ def degrage_BayesCap_p(
928
+ NetC,
929
+ NetG,
930
+ eval_loader,
931
+ device='cuda',
932
+ dtype=torch.cuda.FloatTensor,
933
+ num_runs=50,
934
+ ):
935
+ NetC.to(device)
936
+ NetC.eval()
937
+ NetG.to(device)
938
+ NetG.eval()
939
+
940
+ p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
941
+ list_s = []
942
+ list_p = []
943
+ list_u1 = []
944
+ list_u2 = []
945
+ list_c = []
946
+ for p_mag in p_mag_list:
947
+ mean_ssim = 0
948
+ mean_psnr = 0
949
+ mean_mse = 0
950
+ mean_mae = 0
951
+ num_imgs = 0
952
+ list_error = []
953
+ list_error2 = []
954
+ list_var = []
955
+
956
+ with tqdm(eval_loader, unit='batch') as tepoch:
957
+ for (idx, batch) in enumerate(tepoch):
958
+ tepoch.set_description('Validating ...')
959
+ ##
960
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
961
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
962
+ # pass them through the network
963
+ with torch.no_grad():
964
+ xSR = NetG(xLR)
965
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
966
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
967
+ b_map = xSRC_beta.to('cpu').data
968
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
969
+ n_batch = xSRC_mu.shape[0]
970
+ for j in range(n_batch):
971
+ num_imgs += 1
972
+ mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
973
+ mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
974
+ mean_mse += img_mse(xSRC_mu[j], xSR[j])
975
+ mean_mae += img_mae(xSRC_mu[j], xSR[j])
976
+
977
+ error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
978
+ error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
979
+ var_map = xSRvar[j].to('cpu').data.reshape(-1)
980
+ list_error.extend(list(error_map.numpy()))
981
+ list_error2.extend(list(error_map2.numpy()))
982
+ list_var.extend(list(var_map.numpy()))
983
+ ##
984
+ mean_ssim /= num_imgs
985
+ mean_psnr /= num_imgs
986
+ mean_mse /= num_imgs
987
+ mean_mae /= num_imgs
988
+ print(
989
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
990
+ (
991
+ mean_ssim, mean_psnr, mean_mse, mean_mae
992
+ )
993
+ )
994
+ uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
995
+ uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
996
+ print('UCE1: ', uce1)
997
+ print('UCE2: ', uce2)
998
+ list_s.append(mean_ssim.item())
999
+ list_p.append(mean_psnr.item())
1000
+ list_u1.append(uce1)
1001
+ list_u2.append(uce2)
1002
+
1003
+ plt.plot(list_s)
1004
+ plt.show()
1005
+ plt.plot(list_p)
1006
+ plt.show()
1007
+
1008
+ plt.plot(list_u1, label='wrt SR output')
1009
+ plt.plot(list_u2, label='wrt BayesCap output')
1010
+ plt.legend()
1011
+ plt.show()
1012
+
1013
+ sns.set_style('darkgrid')
1014
+ fig,ax = plt.subplots()
1015
+ # make a plot
1016
+ ax.plot(p_mag_list, list_s, color="red", marker="o")
1017
+ # set x-axis label
1018
+ ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
1019
+ # set y-axis label
1020
+ ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
1021
+
1022
+ # twin object for two different y-axis on the sample plot
1023
+ ax2=ax.twinx()
1024
+ # make a plot with different y-axis using second axis object
1025
+ ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt error btwn SRGAN output and GT')
1026
+ ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt error btwn BayesCap output and GT')
1027
+ ax2.set_ylabel("UCE", color="green", fontsize=10)
1028
+ plt.legend(fontsize=10)
1029
+ plt.tight_layout()
1030
+ plt.show()
1031
+
1032
+ ################# DeepFill_v2
1033
+
1034
+ # ----------------------------------------
1035
+ # PATH processing
1036
+ # ----------------------------------------
1037
+ def text_readlines(filename):
1038
+ # Try to read a txt file and return a list. Return [] if reading fails.
1039
+ try:
1040
+ file = open(filename, 'r')
1041
+ except IOError:
1042
+ error = []
1043
+ return error
1044
+ content = file.readlines()
1045
+ # Strip the trailing newline character from each line
1046
+ for i in range(len(content)):
1047
+ content[i] = content[i][:len(content[i])-1]
1048
+ file.close()
1049
+ return content
1050
+
1051
+ def savetxt(name, loss_log):
1052
+ np_loss_log = np.array(loss_log)
1053
+ np.savetxt(name, np_loss_log)
1054
+
1055
+ def get_files(path):
1056
+ # read a folder, return the complete path
1057
+ ret = []
1058
+ for root, dirs, files in os.walk(path):
1059
+ for filespath in files:
1060
+ ret.append(os.path.join(root, filespath))
1061
+ return ret
1062
+
1063
+ def get_names(path):
1064
+ # read a folder, return the image name
1065
+ ret = []
1066
+ for root, dirs, files in os.walk(path):
1067
+ for filespath in files:
1068
+ ret.append(filespath)
1069
+ return ret
1070
+
1071
+ def text_save(content, filename, mode = 'a'):
1072
+ # save a list to a txt
1073
+ # Try to save a list variable in txt file.
1074
+ file = open(filename, mode)
1075
+ for i in range(len(content)):
1076
+ file.write(str(content[i]) + '\n')
1077
+ file.close()
1078
+
1079
+ def check_path(path):
1080
+ if not os.path.exists(path):
1081
+ os.makedirs(path)
1082
+
1083
+ # ----------------------------------------
1084
+ # Validation and Sample at training
1085
+ # ----------------------------------------
1086
+ def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
1087
+ # Save image one-by-one
1088
+ for i in range(len(img_list)):
1089
+ img = img_list[i]
1090
+ # Recover normalization: * 255 because last layer is sigmoid activated
1091
+ img = img * 255
1092
+ # Process img_copy and do not destroy the data of img
1093
+ img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
1094
+ img_copy = np.clip(img_copy, 0, pixel_max_cnt)
1095
+ img_copy = img_copy.astype(np.uint8)
1096
+ img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
1097
+ # Save to certain path
1098
+ save_img_name = sample_name + '_' + name_list[i] + '.jpg'
1099
+ save_img_path = os.path.join(sample_folder, save_img_name)
1100
+ cv2.imwrite(save_img_path, img_copy)
1101
+
1102
+ def psnr(pred, target, pixel_max_cnt = 255):
1103
+ mse = torch.mul(target - pred, target - pred)
1104
+ rmse_avg = (torch.mean(mse).item()) ** 0.5
1105
+ p = 20 * np.log10(pixel_max_cnt / rmse_avg)
1106
+ return p
1107
+
1108
+ def grey_psnr(pred, target, pixel_max_cnt = 255):
1109
+ pred = torch.sum(pred, dim = 0)
1110
+ target = torch.sum(target, dim = 0)
1111
+ mse = torch.mul(target - pred, target - pred)
1112
+ rmse_avg = (torch.mean(mse).item()) ** 0.5
1113
+ p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
1114
+ return p
1115
+
1116
+ def ssim(pred, target):
1117
+ pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()
1118
+ target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()
1119
+ target = target[0]
1120
+ pred = pred[0]
1121
+ ssim = skimage.measure.compare_ssim(target, pred, multichannel = True)
1122
+ return ssim
1123
+
1124
+ ## for contextual attention
1125
+
1126
+ def extract_image_patches(images, ksizes, strides, rates, padding='same'):
1127
+ """
1128
+ Extract patches from images and put them in the C output dimension.
1129
+ :param padding:
1130
+ :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
1131
+ :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
1132
+ each dimension of images
1133
+ :param strides: [stride_rows, stride_cols]
1134
+ :param rates: [dilation_rows, dilation_cols]
1135
+ :return: A Tensor
1136
+ """
1137
+ assert len(images.size()) == 4
1138
+ assert padding in ['same', 'valid']
1139
+ batch_size, channel, height, width = images.size()
1140
+
1141
+ if padding == 'same':
1142
+ images = same_padding(images, ksizes, strides, rates)
1143
+ elif padding == 'valid':
1144
+ pass
1145
+ else:
1146
+ raise NotImplementedError('Unsupported padding type: {}.\
1147
+ Only "same" or "valid" are supported.'.format(padding))
1148
+
1149
+ unfold = torch.nn.Unfold(kernel_size=ksizes,
1150
+ dilation=rates,
1151
+ padding=0,
1152
+ stride=strides)
1153
+ patches = unfold(images)
1154
+ return patches # [N, C*k*k, L], L is the total number of such blocks
1155
+
1156
+ def same_padding(images, ksizes, strides, rates):
1157
+ assert len(images.size()) == 4
1158
+ batch_size, channel, rows, cols = images.size()
1159
+ out_rows = (rows + strides[0] - 1) // strides[0]
1160
+ out_cols = (cols + strides[1] - 1) // strides[1]
1161
+ effective_k_row = (ksizes[0] - 1) * rates[0] + 1
1162
+ effective_k_col = (ksizes[1] - 1) * rates[1] + 1
1163
+ padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
1164
+ padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
1165
+ # Pad the input
1166
+ padding_top = int(padding_rows / 2.)
1167
+ padding_left = int(padding_cols / 2.)
1168
+ padding_bottom = padding_rows - padding_top
1169
+ padding_right = padding_cols - padding_left
1170
+ paddings = (padding_left, padding_right, padding_top, padding_bottom)
1171
+ images = torch.nn.ZeroPad2d(paddings)(images)
1172
+ return images
1173
+
1174
+ def reduce_mean(x, axis=None, keepdim=False):
1175
+ if not axis:
1176
+ axis = range(len(x.shape))
1177
+ for i in sorted(axis, reverse=True):
1178
+ x = torch.mean(x, dim=i, keepdim=keepdim)
1179
+ return x
1180
+
1181
+
1182
+ def reduce_std(x, axis=None, keepdim=False):
1183
+ if not axis:
1184
+ axis = range(len(x.shape))
1185
+ for i in sorted(axis, reverse=True):
1186
+ x = torch.std(x, dim=i, keepdim=keepdim)
1187
+ return x
1188
+
1189
+
1190
+ def reduce_sum(x, axis=None, keepdim=False):
1191
+ if not axis:
1192
+ axis = range(len(x.shape))
1193
+ for i in sorted(axis, reverse=True):
1194
+ x = torch.sum(x, dim=i, keepdim=keepdim)
1195
+ return x
1196
+
1197
+ def random_mask(num_batch=1, mask_shape=(256,256)):
1198
+ list_mask = []
1199
+ for _ in range(num_batch):
1200
+ # rectangle mask
1201
+ image_height = mask_shape[0]
1202
+ image_width = mask_shape[1]
1203
+ max_delta_height = image_height//8
1204
+ max_delta_width = image_width//8
1205
+ height = image_height//4
1206
+ width = image_width//4
1207
+ max_t = image_height - height
1208
+ max_l = image_width - width
1209
+ t = random.randint(0, max_t)
1210
+ l = random.randint(0, max_l)
1211
+ # bbox = (t, l, height, width)
1212
+ h = random.randint(0, max_delta_height//2)
1213
+ w = random.randint(0, max_delta_width//2)
1214
+ mask = torch.zeros((1, 1, image_height, image_width))
1215
+ mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1
1216
+ rect_mask = mask
1217
+
1218
+ # brush mask
1219
+ min_num_vertex = 4
1220
+ max_num_vertex = 12
1221
+ mean_angle = 2 * math.pi / 5
1222
+ angle_range = 2 * math.pi / 15
1223
+ min_width = 12
1224
+ max_width = 40
1225
+ H, W = image_height, image_width
1226
+ average_radius = math.sqrt(H*H+W*W) / 8
1227
+ mask = Image.new('L', (W, H), 0)
1228
+
1229
+ for _ in range(np.random.randint(1, 4)):
1230
+ num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
1231
+ angle_min = mean_angle - np.random.uniform(0, angle_range)
1232
+ angle_max = mean_angle + np.random.uniform(0, angle_range)
1233
+ angles = []
1234
+ vertex = []
1235
+ for i in range(num_vertex):
1236
+ if i % 2 == 0:
1237
+ angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
1238
+ else:
1239
+ angles.append(np.random.uniform(angle_min, angle_max))
1240
+
1241
+ h, w = mask.size
1242
+ vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
1243
+ for i in range(num_vertex):
1244
+ r = np.clip(
1245
+ np.random.normal(loc=average_radius, scale=average_radius//2),
1246
+ 0, 2*average_radius)
1247
+ new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
1248
+ new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
1249
+ vertex.append((int(new_x), int(new_y)))
1250
+
1251
+ draw = ImageDraw.Draw(mask)
1252
+ width = int(np.random.uniform(min_width, max_width))
1253
+ draw.line(vertex, fill=255, width=width)
1254
+ for v in vertex:
1255
+ draw.ellipse((v[0] - width//2,
1256
+ v[1] - width//2,
1257
+ v[0] + width//2,
1258
+ v[1] + width//2),
1259
+ fill=255)
1260
+
1261
+ if np.random.normal() > 0:
1262
+ mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
1263
+ if np.random.normal() > 0:
1264
+ mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
1265
+
1266
+ mask = transforms.ToTensor()(mask)
1267
+ mask = mask.reshape((1, 1, H, W))
1268
+ brush_mask = mask
1269
+
1270
+ mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
1271
+ list_mask.append(mask)
1272
+ mask = torch.cat(list_mask, dim=0)
1273
+ return mask
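A minimal usage sketch for the random_mask helper above; shapes follow the code, while the batch of inputs and the inpainting network that would consume the mask are hypothetical.
# Illustrative sketch only.
xLR = torch.rand(4, 3, 256, 256)                          # hypothetical batch of input images
xMask = random_mask(num_batch=4, mask_shape=(256, 256))   # -> [4, 1, 256, 256], 1 marks the hole region
xVisible = xLR * (1 - xMask)                               # keep only the visible pixels, as in the inpainting path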
utils.py ADDED
@@ -0,0 +1,1304 @@
1
+ import random
2
+ from typing import Any, Optional
3
+ import numpy as np
4
+ import os
5
+ import cv2
6
+ from glob import glob
7
+ from PIL import Image, ImageDraw
8
+ from tqdm import tqdm
9
+ import kornia
10
+ import matplotlib.pyplot as plt
11
+ import seaborn as sns
12
+ import albumentations as albu
13
+ import functools
14
+ import math
+ import skimage.measure  # needed by ssim() below (compare_ssim; renamed structural_similarity in newer scikit-image)
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from torch import Tensor
19
+ import torchvision as tv
20
+ import torchvision.models as models
21
+ from torchvision import transforms
22
+ from torchvision.transforms import functional as F
23
+ from losses import TempCombLoss
24
+
25
+
26
+ ######## for loading checkpoint from googledrive
27
+ google_drive_paths = {
28
+ "BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
29
+ "BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
30
+ }
31
+
32
+ def ensure_checkpoint_exists(model_weights_filename):
33
+ if not os.path.isfile(model_weights_filename) and (
34
+ model_weights_filename in google_drive_paths
35
+ ):
36
+ gdrive_url = google_drive_paths[model_weights_filename]
37
+ try:
38
+ from gdown import download as drive_download
39
+
40
+ drive_download(gdrive_url, model_weights_filename, quiet=False)
41
+ except ModuleNotFoundError:
42
+ print(
43
+ "gdown module not found.",
44
+ "pip3 install gdown or, manually download the checkpoint file:",
45
+ gdrive_url
46
+ )
47
+
48
+ if not os.path.isfile(model_weights_filename) and (
49
+ model_weights_filename not in google_drive_paths
50
+ ):
51
+ print(
52
+ model_weights_filename,
53
+ " not found, you may need to manually download the model weights."
54
+ )
55
+
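A minimal sketch of how the checkpoint helper above is typically used; the BayesCap module that loads the weights is assumed to be built elsewhere, and the checkpoint layout is not verified here.
ensure_checkpoint_exists('BayesCap_SRGAN.pth')                  # downloads via gdown if the file is missing
# ckpt = torch.load('BayesCap_SRGAN.pth', map_location='cpu')   # layout assumed to be a state_dict
# NetC.load_state_dict(ckpt)                                    # NetC: hypothetical BayesCap network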
56
+ ########### DeblurGAN function
57
+ def get_norm_layer(norm_type='instance'):
58
+ if norm_type == 'batch':
59
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
60
+ elif norm_type == 'instance':
61
+ norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
62
+ else:
63
+ raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
64
+ return norm_layer
65
+
66
+ def _array_to_batch(x):
67
+ x = np.transpose(x, (2, 0, 1))
68
+ x = np.expand_dims(x, 0)
69
+ return torch.from_numpy(x)
70
+
71
+ def get_normalize():
72
+ normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
73
+ normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
74
+
75
+ def process(a, b):
76
+ r = normalize(image=a, target=b)
77
+ return r['image'], r['target']
78
+
79
+ return process
80
+
81
+ def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
82
+ x, _ = get_normalize()(x, x)
83
+ if mask is None:
84
+ mask = np.ones_like(x, dtype=np.float32)
85
+ else:
86
+ mask = np.round(mask.astype('float32') / 255)
87
+
88
+ h, w, _ = x.shape
89
+ block_size = 32
90
+ min_height = (h // block_size + 1) * block_size
91
+ min_width = (w // block_size + 1) * block_size
92
+
93
+ pad_params = {'mode': 'constant',
94
+ 'constant_values': 0,
95
+ 'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
96
+ }
97
+ x = np.pad(x, **pad_params)
98
+ mask = np.pad(mask, **pad_params)
99
+
100
+ return map(_array_to_batch, (x, mask)), h, w
101
+
102
+ def postprocess(x: torch.Tensor) -> np.ndarray:
103
+ x, = x
104
+ x = x.detach().cpu().float().numpy()
105
+ x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
106
+ return x.astype('uint8')
107
+
108
+ def sorted_glob(pattern):
109
+ return sorted(glob(pattern))
110
+ ###########
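A minimal sketch of the DeblurGAN pre/post-processing round trip defined above, assuming an RGB uint8 array img of shape HxWx3 and a hypothetical deblurring network net.
(x, mask), h, w = preprocess(img, None)   # normalize to [-1, 1], pad up to a multiple of 32, add batch dim
with torch.no_grad():
    out = net(x)                          # net is an assumption; expected to return a [-1, 1] image batch
restored = postprocess(out)[:h, :w, :]    # back to uint8 RGB, cropped to the original size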
111
+
112
+ def normalize(image: np.ndarray) -> np.ndarray:
113
+ """Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
114
+ Args:
115
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
116
+ Returns:
117
+ Normalized image data. Data range [0, 1].
118
+ """
119
+ return image.astype(np.float64) / 255.0
120
+
121
+
122
+ def unnormalize(image: np.ndarray) -> np.ndarray:
123
+ """Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
124
+ Args:
125
+ image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
126
+ Returns:
127
+ Denormalized image data. Data range [0, 255].
128
+ """
129
+ return image.astype(np.float64) * 255.0
130
+
131
+
132
+ def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
133
+ """Convert ``PIL.Image`` to Tensor.
134
+ Args:
135
+ image (np.ndarray): The image data read by ``PIL.Image``
136
+ range_norm (bool): Scale [0, 1] data to between [-1, 1]
137
+ half (bool): Whether to convert torch.float32 similarly to torch.half type.
138
+ Returns:
139
+ Normalized image data
140
+ Examples:
141
+ >>> image = Image.open("image.bmp")
142
+ >>> tensor_image = image2tensor(image, range_norm=False, half=False)
143
+ """
144
+ tensor = F.to_tensor(image)
145
+
146
+ if range_norm:
147
+ tensor = tensor.mul_(2.0).sub_(1.0)
148
+ if half:
149
+ tensor = tensor.half()
150
+
151
+ return tensor
152
+
153
+
154
+ def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
155
+ """Converts ``torch.Tensor`` to ``PIL.Image``.
156
+ Args:
157
+ tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
158
+ range_norm (bool): Scale [-1, 1] data to between [0, 1]
159
+ half (bool): Whether to convert torch.float32 similarly to torch.half type.
160
+ Returns:
161
+ Convert image data to support PIL library
162
+ Examples:
163
+ >>> tensor = torch.randn([1, 3, 128, 128])
164
+ >>> image = tensor2image(tensor, range_norm=False, half=False)
165
+ """
166
+ if range_norm:
167
+ tensor = tensor.add_(1.0).div_(2.0)
168
+ if half:
169
+ tensor = tensor.half()
170
+
171
+ image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
172
+
173
+ return image
174
+
175
+
176
+ def convert_rgb_to_y(image: Any) -> Any:
177
+ """Convert RGB image or tensor image data to YCbCr(Y) format.
178
+ Args:
179
+ image: RGB image data read by ``PIL.Image''.
180
+ Returns:
181
+ Y image array data.
182
+ """
183
+ if type(image) == np.ndarray:
184
+ return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
185
+ elif type(image) == torch.Tensor:
186
+ if len(image.shape) == 4:
187
+ image = image.squeeze_(0)
188
+ return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
189
+ else:
190
+ raise Exception("Unknown Type", type(image))
191
+
192
+
193
+ def convert_rgb_to_ycbcr(image: Any) -> Any:
194
+ """Convert RGB image or tensor image data to YCbCr format.
195
+ Args:
196
+ image: RGB image data read by ``PIL.Image''.
197
+ Returns:
198
+ YCbCr image array data.
199
+ """
200
+ if type(image) == np.ndarray:
201
+ y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
202
+ cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
203
+ cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
204
+ return np.array([y, cb, cr]).transpose([1, 2, 0])
205
+ elif type(image) == torch.Tensor:
206
+ if len(image.shape) == 4:
207
+ image = image.squeeze(0)
208
+ y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
209
+ cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
210
+ cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
211
+ return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
212
+ else:
213
+ raise Exception("Unknown Type", type(image))
214
+
215
+
216
+ def convert_ycbcr_to_rgb(image: Any) -> Any:
217
+ """Convert YCbCr format image to RGB format.
218
+ Args:
219
+ image: YCbCr image data read by ``PIL.Image''.
220
+ Returns:
221
+ RGB image array data.
222
+ """
223
+ if type(image) == np.ndarray:
224
+ r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
225
+ g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
226
+ b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
227
+ return np.array([r, g, b]).transpose([1, 2, 0])
228
+ elif type(image) == torch.Tensor:
229
+ if len(image.shape) == 4:
230
+ image = image.squeeze(0)
231
+ r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
232
+ g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
233
+ b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
234
+ return torch.cat([r, g, b], 0).permute(1, 2, 0)
235
+ else:
236
+ raise Exception("Unknown Type", type(image))
237
+
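A small sanity-check sketch for the colour-space helpers above; inputs are on the 0-255 scale these formulas expect.
rgb = np.random.rand(8, 8, 3) * 255.0     # illustrative HxWx3 array
ycbcr = convert_rgb_to_ycbcr(rgb)
rgb_back = convert_ycbcr_to_rgb(ycbcr)
print(np.abs(rgb - rgb_back).max())       # expected to be near zero, up to small numerical drift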
238
+
239
+ def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
240
+ """Cut ``PIL.Image`` in the center area of the image.
241
+ Args:
242
+ lr: Low-resolution image data read by ``PIL.Image``.
243
+ hr: High-resolution image data read by ``PIL.Image``.
244
+ image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
245
+ upscale_factor (int): magnification factor.
246
+ Returns:
247
+ Center-cropped low-resolution and high-resolution images.
248
+ """
249
+ w, h = hr.size
250
+
251
+ left = (w - image_size) // 2
252
+ top = (h - image_size) // 2
253
+ right = left + image_size
254
+ bottom = top + image_size
255
+
256
+ lr = lr.crop((left // upscale_factor,
257
+ top // upscale_factor,
258
+ right // upscale_factor,
259
+ bottom // upscale_factor))
260
+ hr = hr.crop((left, top, right, bottom))
261
+
262
+ return lr, hr
263
+
264
+
265
+ def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
266
+ """Will ``PIL.Image`` randomly capture the specified area of the image.
267
+ Args:
268
+ lr: Low-resolution image data read by ``PIL.Image``.
269
+ hr: High-resolution image data read by ``PIL.Image``.
270
+ image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
271
+ upscale_factor (int): magnification factor.
272
+ Returns:
273
+ Randomly cropped low-resolution images and high-resolution images.
274
+ """
275
+ w, h = hr.size
276
+ left = torch.randint(0, w - image_size + 1, size=(1,)).item()
277
+ top = torch.randint(0, h - image_size + 1, size=(1,)).item()
278
+ right = left + image_size
279
+ bottom = top + image_size
280
+
281
+ lr = lr.crop((left // upscale_factor,
282
+ top // upscale_factor,
283
+ right // upscale_factor,
284
+ bottom // upscale_factor))
285
+ hr = hr.crop((left, top, right, bottom))
286
+
287
+ return lr, hr
288
+
289
+
290
+ def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
291
+ """Will ``PIL.Image`` randomly rotate the image.
292
+ Args:
293
+ lr: Low-resolution image data read by ``PIL.Image``.
294
+ hr: High-resolution image data read by ``PIL.Image``.
295
+ angle (int): rotation angle, clockwise and counterclockwise rotation.
296
+ Returns:
297
+ Randomly rotated low-resolution images and high-resolution images.
298
+ """
299
+ angle = random.choice((+angle, -angle))
300
+ lr = F.rotate(lr, angle)
301
+ hr = F.rotate(hr, angle)
302
+
303
+ return lr, hr
304
+
305
+
306
+ def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
307
+ """Flip the ``PIL.Image`` image horizontally randomly.
308
+ Args:
309
+ lr: Low-resolution image data read by ``PIL.Image``.
310
+ hr: High-resolution image data read by ``PIL.Image``.
311
+ p (optional, float): flip probability. (Default: 0.5)
312
+ Returns:
313
+ Low-resolution image and high-resolution image after random horizontal flip.
314
+ """
315
+ if torch.rand(1).item() > p:
316
+ lr = F.hflip(lr)
317
+ hr = F.hflip(hr)
318
+
319
+ return lr, hr
320
+
321
+
322
+ def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
323
+ """Turn the ``PIL.Image`` image upside down randomly.
324
+ Args:
325
+ lr: Low-resolution image data read by ``PIL.Image``.
326
+ hr: High-resolution image data read by ``PIL.Image``.
327
+ p (optional, float): flip probability. (Default: 0.5)
328
+ Returns:
329
+ Low-resolution image and high-resolution image after random vertical flip.
330
+ """
331
+ if torch.rand(1).item() > p:
332
+ lr = F.vflip(lr)
333
+ hr = F.vflip(hr)
334
+
335
+ return lr, hr
336
+
337
+
338
+ def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
339
+ """Set ``PIL.Image`` to randomly adjust the image brightness.
340
+ Args:
341
+ lr: Low-resolution image data read by ``PIL.Image``.
342
+ hr: High-resolution image data read by ``PIL.Image``.
343
+ Returns:
344
+ Low-resolution image and high-resolution image with randomly adjusted brightness.
345
+ """
346
+ # Randomly adjust the brightness gain range.
347
+ factor = random.uniform(0.5, 2)
348
+ lr = F.adjust_brightness(lr, factor)
349
+ hr = F.adjust_brightness(hr, factor)
350
+
351
+ return lr, hr
352
+
353
+
354
+ def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
355
+ """Set ``PIL.Image`` to randomly adjust the image contrast.
356
+ Args:
357
+ lr: Low-resolution image data read by ``PIL.Image``.
358
+ hr: High-resolution image data read by ``PIL.Image``.
359
+ Returns:
360
+ Low-resolution image and high-resolution image with randomly adjusted contrast.
361
+ """
362
+ # Randomly adjust the contrast gain range.
363
+ factor = random.uniform(0.5, 2)
364
+ lr = F.adjust_contrast(lr, factor)
365
+ hr = F.adjust_contrast(hr, factor)
366
+
367
+ return lr, hr
368
+
369
+ #### metrics to compute -- assumes single images, i.e., tensor of 3 dims
370
+ def img_mae(x1, x2):
371
+ m = torch.abs(x1-x2).mean()
372
+ return m
373
+
374
+ def img_mse(x1, x2):
375
+ m = torch.pow(torch.abs(x1-x2),2).mean()
376
+ return m
377
+
378
+ def img_psnr(x1, x2):
379
+ m = kornia.metrics.psnr(x1, x2, 1)
380
+ return m
381
+
382
+ def img_ssim(x1, x2):
383
+ m = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
384
+ m = m.mean()
385
+ return m
386
+
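The four metric helpers above operate on single images, i.e. CxHxW tensors with values in [0, 1]; a minimal sketch on synthetic data.
x = torch.rand(3, 64, 64)
y = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)   # noisy copy of x
print(img_mae(x, y).item(), img_mse(x, y).item())
print(img_psnr(x, y).item(), img_ssim(x, y).item())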
387
+ def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
388
+ '''
389
+ xLR/SR/HR: 3xHxW
390
+ xSRvar: 1xHxW
391
+ '''
392
+ plt.figure(figsize=(30,10))
393
+
394
+ plt.subplot(1,5,1)
395
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
396
+ plt.axis('off')
397
+
398
+ plt.subplot(1,5,2)
399
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
400
+ plt.axis('off')
401
+
402
+ plt.subplot(1,5,3)
403
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
404
+ plt.axis('off')
405
+
406
+ plt.subplot(1,5,4)
407
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
408
+ print('error', error_map.min(), error_map.max())
409
+ plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
410
+ plt.clim(elim[0], elim[1])
411
+ plt.axis('off')
412
+
413
+ plt.subplot(1,5,5)
414
+ print('uncer', xSRvar.min(), xSRvar.max())
415
+ plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
416
+ plt.clim(ulim[0], ulim[1])
417
+ plt.axis('off')
418
+
419
+ plt.subplots_adjust(wspace=0, hspace=0)
420
+ plt.show()
421
+
422
+ def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
423
+ '''
424
+ xLR/SR/HR: 3xHxW
425
+ '''
426
+ plt.figure(figsize=(30,10))
427
+
428
+ if task != 'm':
429
+ plt.subplot(1,4,1)
430
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
431
+ plt.axis('off')
432
+
433
+ plt.subplot(1,4,2)
434
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
435
+ plt.axis('off')
436
+
437
+ plt.subplot(1,4,3)
438
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
439
+ plt.axis('off')
440
+ else:
441
+ plt.subplot(1,4,1)
442
+ plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
443
+ plt.clim(0,0.9)
444
+ plt.axis('off')
445
+
446
+ plt.subplot(1,4,2)
447
+ plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
448
+ plt.clim(0,0.9)
449
+ plt.axis('off')
450
+
451
+ plt.subplot(1,4,3)
452
+ plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
453
+ plt.clim(0,0.9)
454
+ plt.axis('off')
455
+
456
+ plt.subplot(1,4,4)
457
+ if task == 'inpainting':
458
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
459
+ else:
460
+ error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
461
+ print('error', error_map.min(), error_map.max())
462
+ plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
463
+ plt.clim(elim[0], elim[1])
464
+ plt.axis('off')
465
+
466
+ plt.subplots_adjust(wspace=0, hspace=0)
467
+ plt.show()
468
+
469
+ def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
470
+ '''
471
+ xSRvar: 1xHxW
472
+ '''
473
+ plt.figure(figsize=(30,10))
474
+
475
+ plt.subplot(1,4,1)
476
+ print('uncer', xSRvar1.min(), xSRvar1.max())
477
+ plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
478
+ plt.clim(ulim[0], ulim[1])
479
+ plt.axis('off')
480
+
481
+ plt.subplot(1,4,2)
482
+ print('uncer', xSRvar2.min(), xSRvar2.max())
483
+ plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
484
+ plt.clim(ulim[0], ulim[1])
485
+ plt.axis('off')
486
+
487
+ plt.subplot(1,4,3)
488
+ print('uncer', xSRvar3.min(), xSRvar3.max())
489
+ plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
490
+ plt.clim(ulim[0], ulim[1])
491
+ plt.axis('off')
492
+
493
+ plt.subplot(1,4,4)
494
+ print('uncer', xSRvar4.min(), xSRvar4.max())
495
+ plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
496
+ plt.clim(ulim[0], ulim[1])
497
+ plt.axis('off')
498
+
499
+ plt.subplots_adjust(wspace=0, hspace=0)
500
+ plt.show()
501
+
502
+ def get_UCE(list_err, list_yout_var, num_bins=100):
503
+ err_min = np.min(list_err)
504
+ err_max = np.max(list_err)
505
+ err_len = (err_max-err_min)/num_bins
506
+ num_points = len(list_err)
507
+
508
+ bin_stats = {}
509
+ for i in range(num_bins):
510
+ bin_stats[i] = {
511
+ 'start_idx': err_min + i*err_len,
512
+ 'end_idx': err_min + (i+1)*err_len,
513
+ 'num_points': 0,
514
+ 'mean_err': 0,
515
+ 'mean_var': 0,
516
+ }
517
+
518
+ for e,v in zip(list_err, list_yout_var):
519
+ for i in range(num_bins):
520
+ if e>=bin_stats[i]['start_idx'] and e<bin_stats[i]['end_idx']:
521
+ bin_stats[i]['num_points'] += 1
522
+ bin_stats[i]['mean_err'] += e
523
+ bin_stats[i]['mean_var'] += v
524
+
525
+ uce = 0
526
+ eps = 1e-8
527
+ for i in range(num_bins):
528
+ bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
529
+ bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
530
+ bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points']/num_points) \
531
+ *(np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
532
+ uce += bin_stats[i]['uce_bin']
533
+
534
+ list_x, list_y = [], []
535
+ for i in range(num_bins):
536
+ if bin_stats[i]['num_points']>0:
537
+ list_x.append(bin_stats[i]['mean_err'])
538
+ list_y.append(bin_stats[i]['mean_var'])
539
+
540
+ # sns.set_style('darkgrid')
541
+ # sns.scatterplot(x=list_x, y=list_y)
542
+ # sns.regplot(x=list_x, y=list_y, order=1)
543
+ # plt.xlabel('MSE', fontsize=34)
544
+ # plt.ylabel('Uncertainty', fontsize=34)
545
+ # plt.plot(list_x, list_x, color='r')
546
+ # plt.xlim(np.min(list_x), np.max(list_x))
547
+ # plt.ylim(np.min(list_err), np.max(list_x))
548
+ # plt.show()
549
+
550
+ return bin_stats, uce
551
+
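A minimal sketch of what get_UCE consumes: two flat lists of per-pixel errors and predicted variances, binned by error. A perfectly calibrated predictor (variance equal to error in every bin) gives a UCE near zero, while a mis-scaled one does not.
errs = np.random.rand(10000) * 0.01                              # synthetic per-pixel errors
_, uce_good = get_UCE(list(errs), list(errs), num_bins=50)       # variance matches error exactly
_, uce_bad = get_UCE(list(errs), list(10 * errs), num_bins=50)   # variance off by a factor of 10
print(uce_good, uce_bad)                                         # uce_good ~ 0, uce_bad clearly larger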
552
+ ##################### training BayesCap
553
+ def train_BayesCap(
554
+ NetC,
555
+ NetG,
556
+ train_loader,
557
+ eval_loader,
558
+ Cri = TempCombLoss(),
559
+ device='cuda',
560
+ dtype=torch.cuda.FloatTensor,
561
+ init_lr=1e-4,
562
+ num_epochs=100,
563
+ eval_every=1,
564
+ ckpt_path='../ckpt/BayesCap',
565
+ T1=1e0,
566
+ T2=5e-2,
567
+ task=None,
568
+ ):
569
+ NetC.to(device)
570
+ NetC.train()
571
+ NetG.to(device)
572
+ NetG.eval()
573
+ optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
574
+ optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
575
+
576
+ score = -1e8
577
+ all_loss = []
578
+ for eph in range(num_epochs):
579
+ eph_loss = 0
580
+ with tqdm(train_loader, unit='batch') as tepoch:
581
+ for (idx, batch) in enumerate(tepoch):
582
+ if idx>2000:
583
+ break
584
+ tepoch.set_description('Epoch {}'.format(eph))
585
+ ##
586
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
587
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
588
+ if task == 'inpainting':
589
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
590
+ xMask = xMask.to(device).type(dtype)
591
+ # pass them through the network
592
+ with torch.no_grad():
593
+ if task == 'inpainting':
594
+ _, xSR1 = NetG(xLR, xMask)
595
+ elif task == 'depth':
596
+ xSR1 = NetG(xLR)[("disp", 0)]
597
+ else:
598
+ xSR1 = NetG(xLR)
599
+ # with torch.autograd.set_detect_anomaly(True):
600
+ xSR = xSR1.clone()
601
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
602
+ # print(xSRC_alpha)
603
+ optimizer.zero_grad()
604
+ if task == 'depth':
605
+ loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
606
+ else:
607
+ loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
608
+ # print(loss)
609
+ loss.backward()
610
+ optimizer.step()
611
+ ##
612
+ eph_loss += loss.item()
613
+ tepoch.set_postfix(loss=loss.item())
614
+ eph_loss /= len(train_loader)
615
+ all_loss.append(eph_loss)
616
+ print('Avg. loss: {}'.format(eph_loss))
617
+ # evaluate and save the models
618
+ torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
619
+ if eph%eval_every == 0:
620
+ curr_score = eval_BayesCap(
621
+ NetC,
622
+ NetG,
623
+ eval_loader,
624
+ device=device,
625
+ dtype=dtype,
626
+ task=task,
627
+ )
628
+ print('current score: {} | Last best score: {}'.format(curr_score, score))
629
+ if curr_score >= score:
630
+ score = curr_score
631
+ torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
632
+ optim_scheduler.step()
633
+
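A minimal sketch of how the training loop above is typically invoked; NetC (the BayesCap model), NetG (the frozen generator) and the two data loaders are placeholders assumed to be constructed elsewhere.
# NetC, NetG, train_loader, eval_loader: hypothetical objects built outside this file.
train_BayesCap(
    NetC, NetG, train_loader, eval_loader,
    init_lr=1e-4, num_epochs=50, eval_every=1,
    ckpt_path='../ckpt/BayesCap_SRGAN',
    T1=1e0, T2=5e-2, task=None,
)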
634
+ #### get different uncertainty maps
635
+ def get_uncer_BayesCap(
636
+ NetC,
637
+ NetG,
638
+ xin,
639
+ task=None,
640
+ xMask=None,
641
+ ):
642
+ with torch.no_grad():
643
+ if task == 'inpainting':
644
+ _, xSR = NetG(xin, xMask)
645
+ else:
646
+ xSR = NetG(xin)
647
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
648
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
649
+ b_map = xSRC_beta.to('cpu').data
650
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
651
+
652
+ return xSRvar
653
+
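The per-pixel variance computed here (and in the eval functions elsewhere in this file) is the variance of a zero-mean generalized Gaussian with scale 1/alpha and shape beta, i.e. Var = (1/alpha)^2 * Gamma(3/beta) / Gamma(1/beta); a standalone sketch of the same computation with the same numerical epsilons.
def ggd_variance(alpha, beta, eps_a=1e-5, eps_b=1e-2):
    a = 1.0 / (alpha + eps_a)   # scale parameter of the generalized Gaussian
    b = beta + eps_b            # shape parameter
    return (a ** 2) * torch.exp(torch.lgamma(3.0 / b) - torch.lgamma(1.0 / b))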
654
+ def get_uncer_TTDAp(
655
+ NetG,
656
+ xin,
657
+ p_mag=0.05,
658
+ num_runs=50,
659
+ task=None,
660
+ xMask=None,
661
+ ):
662
+ list_xSR = []
663
+ with torch.no_grad():
664
+ for z in range(num_runs):
665
+ if task == 'inpainting':
666
+ _, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask)
667
+ else:
668
+ xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin))
669
+ list_xSR.append(xSRz)
670
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
671
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
672
+ return xSRvar
673
+
674
+ def get_uncer_DO(
675
+ NetG,
676
+ xin,
677
+ dop=0.2,
678
+ num_runs=50,
679
+ task=None,
680
+ xMask=None,
681
+ ):
682
+ list_xSR = []
683
+ with torch.no_grad():
684
+ for z in range(num_runs):
685
+ if task == 'inpainting':
686
+ _, xSRz = NetG(xin, xMask, dop=dop)
687
+ else:
688
+ xSRz = NetG(xin, dop=dop)
689
+ list_xSR.append(xSRz)
690
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
691
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
692
+ return xSRvar
693
+
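A minimal sketch of calling the three uncertainty estimators above on one batch, mirroring what compare_all does further down; NetC, NetG and xLR are placeholders for the networks and input used elsewhere.
var_ttda = get_uncer_TTDAp(NetG, xLR, p_mag=0.05, num_runs=50)   # test-time input perturbation
var_do = get_uncer_DO(NetG, xLR, dop=0.2, num_runs=50)           # dropout at test time (NetG must accept dop)
var_bc = get_uncer_BayesCap(NetC, NetG, xLR)                     # BayesCap variance map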
694
+ ################### Different eval functions
695
+
696
+ def eval_BayesCap(
697
+ NetC,
698
+ NetG,
699
+ eval_loader,
700
+ device='cuda',
701
+ dtype=torch.cuda.FloatTensor,
702
+ task=None,
703
+ xMask=None,
704
+ ):
705
+ NetC.to(device)
706
+ NetC.eval()
707
+ NetG.to(device)
708
+ NetG.eval()
709
+
710
+ mean_ssim = 0
711
+ mean_psnr = 0
712
+ mean_mse = 0
713
+ mean_mae = 0
714
+ num_imgs = 0
715
+ list_error = []
716
+ list_var = []
717
+ with tqdm(eval_loader, unit='batch') as tepoch:
718
+ for (idx, batch) in enumerate(tepoch):
719
+ tepoch.set_description('Validating ...')
720
+ ##
721
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
722
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
723
+ if task == 'inpainting':
724
+ if xMask is None:
725
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
726
+ xMask = xMask.to(device).type(dtype)
727
+ else:
728
+ xMask = xMask.to(device).type(dtype)
729
+ # pass them through the network
730
+ with torch.no_grad():
731
+ if task == 'inpainting':
732
+ _, xSR = NetG(xLR, xMask)
733
+ elif task == 'depth':
734
+ xSR = NetG(xLR)[("disp", 0)]
735
+ else:
736
+ xSR = NetG(xLR)
737
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
738
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
739
+ b_map = xSRC_beta.to('cpu').data
740
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
741
+ n_batch = xSRC_mu.shape[0]
742
+ if task == 'depth':
743
+ xHR = xSR
744
+ for j in range(n_batch):
745
+ num_imgs += 1
746
+ mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
747
+ mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
748
+ mean_mse += img_mse(xSRC_mu[j], xHR[j])
749
+ mean_mae += img_mae(xSRC_mu[j], xHR[j])
750
+
751
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
752
+
753
+ error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
754
+ var_map = xSRvar[j].to('cpu').data.reshape(-1)
755
+ list_error.extend(list(error_map.numpy()))
756
+ list_var.extend(list(var_map.numpy()))
757
+ ##
758
+ mean_ssim /= num_imgs
759
+ mean_psnr /= num_imgs
760
+ mean_mse /= num_imgs
761
+ mean_mae /= num_imgs
762
+ print(
763
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
764
+ (
765
+ mean_ssim, mean_psnr, mean_mse, mean_mae
766
+ )
767
+ )
768
+ # print(len(list_error), len(list_var))
769
+ # print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
770
+ # print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
771
+ return mean_ssim
772
+
773
+ def eval_TTDA_p(
774
+ NetG,
775
+ eval_loader,
776
+ device='cuda',
777
+ dtype=torch.cuda.FloatTensor,
778
+ p_mag=0.05,
779
+ num_runs=50,
780
+ task = None,
781
+ xMask = None,
782
+ ):
783
+ NetG.to(device)
784
+ NetG.eval()
785
+
786
+ mean_ssim = 0
787
+ mean_psnr = 0
788
+ mean_mse = 0
789
+ mean_mae = 0
790
+ num_imgs = 0
791
+ with tqdm(eval_loader, unit='batch') as tepoch:
792
+ for (idx, batch) in enumerate(tepoch):
793
+ tepoch.set_description('Validating ...')
794
+ ##
795
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
796
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
797
+ # pass them through the network
798
+ list_xSR = []
799
+ with torch.no_grad():
800
+ if task=='inpainting':
801
+ _, xSR = NetG(xLR, xMask)
802
+ else:
803
+ xSR = NetG(xLR)
804
+ for z in range(num_runs):
805
+ xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
806
+ list_xSR.append(xSRz)
807
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
808
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
809
+ n_batch = xSR.shape[0]
810
+ for j in range(n_batch):
811
+ num_imgs += 1
812
+ mean_ssim += img_ssim(xSR[j], xHR[j])
813
+ mean_psnr += img_psnr(xSR[j], xHR[j])
814
+ mean_mse += img_mse(xSR[j], xHR[j])
815
+ mean_mae += img_mae(xSR[j], xHR[j])
816
+
817
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
818
+
819
+ mean_ssim /= num_imgs
820
+ mean_psnr /= num_imgs
821
+ mean_mse /= num_imgs
822
+ mean_mae /= num_imgs
823
+ print(
824
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
825
+ (
826
+ mean_ssim, mean_psnr, mean_mse, mean_mae
827
+ )
828
+ )
829
+
830
+ return mean_ssim
831
+
832
+ def eval_DO(
833
+ NetG,
834
+ eval_loader,
835
+ device='cuda',
836
+ dtype=torch.cuda.FloatTensor,
837
+ dop=0.2,
838
+ num_runs=50,
839
+ task=None,
840
+ xMask=None,
841
+ ):
842
+ NetG.to(device)
843
+ NetG.eval()
844
+
845
+ mean_ssim = 0
846
+ mean_psnr = 0
847
+ mean_mse = 0
848
+ mean_mae = 0
849
+ num_imgs = 0
850
+ with tqdm(eval_loader, unit='batch') as tepoch:
851
+ for (idx, batch) in enumerate(tepoch):
852
+ tepoch.set_description('Validating ...')
853
+ ##
854
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
855
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
856
+ # pass them through the network
857
+ list_xSR = []
858
+ with torch.no_grad():
859
+ if task == 'inpainting':
860
+ _, xSR = NetG(xLR, xMask)
861
+ else:
862
+ xSR = NetG(xLR)
863
+ for z in range(num_runs):
864
+ xSRz = NetG(xLR, dop=dop)
865
+ list_xSR.append(xSRz)
866
+ xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
867
+ xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
868
+ n_batch = xSR.shape[0]
869
+ for j in range(n_batch):
870
+ num_imgs += 1
871
+ mean_ssim += img_ssim(xSR[j], xHR[j])
872
+ mean_psnr += img_psnr(xSR[j], xHR[j])
873
+ mean_mse += img_mse(xSR[j], xHR[j])
874
+ mean_mae += img_mae(xSR[j], xHR[j])
875
+
876
+ show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
877
+ ##
878
+ mean_ssim /= num_imgs
879
+ mean_psnr /= num_imgs
880
+ mean_mse /= num_imgs
881
+ mean_mae /= num_imgs
882
+ print(
883
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
884
+ (
885
+ mean_ssim, mean_psnr, mean_mse, mean_mae
886
+ )
887
+ )
888
+
889
+ return mean_ssim
890
+
891
+
892
+ ############### compare all function
893
+ def compare_all(
894
+ NetC,
895
+ NetG,
896
+ eval_loader,
897
+ p_mag = 0.05,
898
+ dop = 0.2,
899
+ num_runs = 100,
900
+ device='cuda',
901
+ dtype=torch.cuda.FloatTensor,
902
+ task=None,
903
+ ):
904
+ NetC.to(device)
905
+ NetC.eval()
906
+ NetG.to(device)
907
+ NetG.eval()
908
+
909
+ with tqdm(eval_loader, unit='batch') as tepoch:
910
+ for (idx, batch) in enumerate(tepoch):
911
+ tepoch.set_description('Comparing ...')
912
+ ##
913
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
914
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
915
+ if task == 'inpainting':
916
+ xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
917
+ xMask = xMask.to(device).type(dtype)
918
+ # pass them through the network
919
+ with torch.no_grad():
920
+ if task == 'inpainting':
921
+ _, xSR = NetG(xLR, xMask)
922
+ else:
923
+ xSR = NetG(xLR)
924
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
925
+
926
+ if task == 'inpainting':
927
+ xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
928
+ xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
929
+ xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
930
+ else:
931
+ xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
932
+ xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
933
+ xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)
934
+
935
+ print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
936
+
937
+ n_batch = xSR.shape[0]
938
+ for j in range(n_batch):
939
+ if task=='s':
940
+ show_SR_w_err(xLR[j], xHR[j], xSR[j])
941
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
942
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
943
+ if task=='d':
944
+ show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
945
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
946
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
947
+ if task=='inpainting':
948
+ show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
949
+ show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
950
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
951
+ if task=='m':
952
+ show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
953
+ show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
954
+ show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
955
+
956
+
957
+ ################# Degrading Identity
958
+ def degrage_BayesCap_p(
959
+ NetC,
960
+ NetG,
961
+ eval_loader,
962
+ device='cuda',
963
+ dtype=torch.cuda.FloatTensor,
964
+ num_runs=50,
965
+ ):
966
+ NetC.to(device)
967
+ NetC.eval()
968
+ NetG.to(device)
969
+ NetG.eval()
970
+
971
+ p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
972
+ list_s = []
973
+ list_p = []
974
+ list_u1 = []
975
+ list_u2 = []
976
+ list_c = []
977
+ for p_mag in p_mag_list:
978
+ mean_ssim = 0
979
+ mean_psnr = 0
980
+ mean_mse = 0
981
+ mean_mae = 0
982
+ num_imgs = 0
983
+ list_error = []
984
+ list_error2 = []
985
+ list_var = []
986
+
987
+ with tqdm(eval_loader, unit='batch') as tepoch:
988
+ for (idx, batch) in enumerate(tepoch):
989
+ tepoch.set_description('Validating ...')
990
+ ##
991
+ xLR, xHR = batch[0].to(device), batch[1].to(device)
992
+ xLR, xHR = xLR.type(dtype), xHR.type(dtype)
993
+ # pass them through the network
994
+ with torch.no_grad():
995
+ xSR = NetG(xLR)
996
+ xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
997
+ a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
998
+ b_map = xSRC_beta.to('cpu').data
999
+ xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
1000
+ n_batch = xSRC_mu.shape[0]
1001
+ for j in range(n_batch):
1002
+ num_imgs += 1
1003
+ mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
1004
+ mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
1005
+ mean_mse += img_mse(xSRC_mu[j], xSR[j])
1006
+ mean_mae += img_mae(xSRC_mu[j], xSR[j])
1007
+
1008
+ error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
1009
+ error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
1010
+ var_map = xSRvar[j].to('cpu').data.reshape(-1)
1011
+ list_error.extend(list(error_map.numpy()))
1012
+ list_error2.extend(list(error_map2.numpy()))
1013
+ list_var.extend(list(var_map.numpy()))
1014
+ ##
1015
+ mean_ssim /= num_imgs
1016
+ mean_psnr /= num_imgs
1017
+ mean_mse /= num_imgs
1018
+ mean_mae /= num_imgs
1019
+ print(
1020
+ 'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
1021
+ (
1022
+ mean_ssim, mean_psnr, mean_mse, mean_mae
1023
+ )
1024
+ )
1025
+ uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
1026
+ uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
1027
+ print('UCE1: ', uce1)
1028
+ print('UCE2: ', uce2)
1029
+ list_s.append(mean_ssim.item())
1030
+ list_p.append(mean_psnr.item())
1031
+ list_u1.append(uce1)
1032
+ list_u2.append(uce2)
1033
+
1034
+ plt.plot(list_s)
1035
+ plt.show()
1036
+ plt.plot(list_p)
1037
+ plt.show()
1038
+
1039
+ plt.plot(list_u1, label='wrt SR output')
1040
+ plt.plot(list_u2, label='wrt BayesCap output')
1041
+ plt.legend()
1042
+ plt.show()
1043
+
1044
+ sns.set_style('darkgrid')
1045
+ fig,ax = plt.subplots()
1046
+ # make a plot
1047
+ ax.plot(p_mag_list, list_s, color="red", marker="o")
1048
+ # set x-axis label
1049
+ ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
1050
+ # set y-axis label
1051
+ ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
1052
+
1053
+ # twin object for two different y-axis on the sample plot
1054
+ ax2=ax.twinx()
1055
+ # make a plot with different y-axis using second axis object
1056
+ ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt error btwn SRGAN output and GT')
1057
+ ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt error btwn BayesCap output and GT')
1058
+ ax2.set_ylabel("UCE", color="green", fontsize=10)
1059
+ plt.legend(fontsize=10)
1060
+ plt.tight_layout()
1061
+ plt.show()
1062
+
1063
+ ################# DeepFill_v2
1064
+
1065
+ # ----------------------------------------
1066
+ # PATH processing
1067
+ # ----------------------------------------
1068
+ def text_readlines(filename):
1069
+ # Try to read a txt file and return a list. Return [] if reading fails.
1070
+ try:
1071
+ file = open(filename, 'r')
1072
+ except IOError:
1073
+ error = []
1074
+ return error
1075
+ content = file.readlines()
1076
+ # Strip the trailing newline character from each line
1077
+ for i in range(len(content)):
1078
+ content[i] = content[i][:len(content[i])-1]
1079
+ file.close()
1080
+ return content
1081
+
1082
+ def savetxt(name, loss_log):
1083
+ np_loss_log = np.array(loss_log)
1084
+ np.savetxt(name, np_loss_log)
1085
+
1086
+ def get_files(path):
1087
+ # read a folder, return the complete path
1088
+ ret = []
1089
+ for root, dirs, files in os.walk(path):
1090
+ for filespath in files:
1091
+ ret.append(os.path.join(root, filespath))
1092
+ return ret
1093
+
1094
+ def get_names(path):
1095
+ # read a folder, return the image name
1096
+ ret = []
1097
+ for root, dirs, files in os.walk(path):
1098
+ for filespath in files:
1099
+ ret.append(filespath)
1100
+ return ret
1101
+
1102
+ def text_save(content, filename, mode = 'a'):
1103
+ # save a list to a txt
1104
+ # Try to save a list variable in txt file.
1105
+ file = open(filename, mode)
1106
+ for i in range(len(content)):
1107
+ file.write(str(content[i]) + '\n')
1108
+ file.close()
1109
+
1110
+ def check_path(path):
1111
+ if not os.path.exists(path):
1112
+ os.makedirs(path)
1113
+
1114
+ # ----------------------------------------
1115
+ # Validation and Sample at training
1116
+ # ----------------------------------------
1117
+ def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
1118
+ # Save image one-by-one
1119
+ for i in range(len(img_list)):
1120
+ img = img_list[i]
1121
+ # Recover normalization: * 255 because last layer is sigmoid activated
1122
+ img = img * 255
1123
+ # Process img_copy and do not destroy the data of img
1124
+ img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
1125
+ img_copy = np.clip(img_copy, 0, pixel_max_cnt)
1126
+ img_copy = img_copy.astype(np.uint8)
1127
+ img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
1128
+ # Save to certain path
1129
+ save_img_name = sample_name + '_' + name_list[i] + '.jpg'
1130
+ save_img_path = os.path.join(sample_folder, save_img_name)
1131
+ cv2.imwrite(save_img_path, img_copy)
1132
+
1133
+ def psnr(pred, target, pixel_max_cnt = 255):
1134
+ mse = torch.mul(target - pred, target - pred)
1135
+ rmse_avg = (torch.mean(mse).item()) ** 0.5
1136
+ p = 20 * np.log10(pixel_max_cnt / rmse_avg)
1137
+ return p
1138
+
1139
+ def grey_psnr(pred, target, pixel_max_cnt = 255):
1140
+ pred = torch.sum(pred, dim = 0)
1141
+ target = torch.sum(target, dim = 0)
1142
+ mse = torch.mul(target - pred, target - pred)
1143
+ rmse_avg = (torch.mean(mse).item()) ** 0.5
1144
+ p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
1145
+ return p
1146
+
1147
+ def ssim(pred, target):
1148
+ pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()
1149
+ target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()
1150
+ target = target[0]
1151
+ pred = pred[0]
1152
+ ssim = skimage.measure.compare_ssim(target, pred, multichannel = True)
1153
+ return ssim
1154
+
1155
+ ## for contextual attention
1156
+
1157
+ def extract_image_patches(images, ksizes, strides, rates, padding='same'):
1158
+ """
1159
+ Extract patches from images and put them in the C output dimension.
1160
+ :param padding:
1161
+ :param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
1162
+ :param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
1163
+ each dimension of images
1164
+ :param strides: [stride_rows, stride_cols]
1165
+ :param rates: [dilation_rows, dilation_cols]
1166
+ :return: A Tensor
1167
+ """
1168
+ assert len(images.size()) == 4
1169
+ assert padding in ['same', 'valid']
1170
+ batch_size, channel, height, width = images.size()
1171
+
1172
+ if padding == 'same':
1173
+ images = same_padding(images, ksizes, strides, rates)
1174
+ elif padding == 'valid':
1175
+ pass
1176
+ else:
1177
+ raise NotImplementedError('Unsupported padding type: {}.\
1178
+ Only "same" or "valid" are supported.'.format(padding))
1179
+
1180
+ unfold = torch.nn.Unfold(kernel_size=ksizes,
1181
+ dilation=rates,
1182
+ padding=0,
1183
+ stride=strides)
1184
+ patches = unfold(images)
1185
+ return patches # [N, C*k*k, L], L is the total number of such blocks
1186
+
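A quick shape sketch for the patch-extraction helper above; with 'same' padding the output grid stays at input_size / stride per spatial dimension.
imgs = torch.rand(1, 16, 32, 32)
patches = extract_image_patches(imgs, ksizes=[3, 3], strides=[2, 2], rates=[1, 1], padding='same')
print(patches.shape)   # [1, 16*3*3, 16*16] = [1, 144, 256]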
1187
+ def same_padding(images, ksizes, strides, rates):
+     assert len(images.size()) == 4
+     batch_size, channel, rows, cols = images.size()
+     out_rows = (rows + strides[0] - 1) // strides[0]
+     out_cols = (cols + strides[1] - 1) // strides[1]
+     effective_k_row = (ksizes[0] - 1) * rates[0] + 1
+     effective_k_col = (ksizes[1] - 1) * rates[1] + 1
+     padding_rows = max(0, (out_rows - 1) * strides[0] + effective_k_row - rows)
+     padding_cols = max(0, (out_cols - 1) * strides[1] + effective_k_col - cols)
+     # Pad the input, TensorFlow "SAME"-style: split the padding as evenly as
+     # possible, with the odd pixel going to the bottom/right.
+     padding_top = int(padding_rows / 2.)
+     padding_left = int(padding_cols / 2.)
+     padding_bottom = padding_rows - padding_top
+     padding_right = padding_cols - padding_left
+     paddings = (padding_left, padding_right, padding_top, padding_bottom)
+     images = torch.nn.ZeroPad2d(paddings)(images)
+     return images
+
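+ # Worked example (illustrative numbers): for rows = cols = 32, ksizes = [3, 3],
+ # strides = [1, 1], rates = [1, 1]:
+ #     out_rows        = (32 + 1 - 1) // 1 = 32
+ #     effective_k_row = (3 - 1) * 1 + 1   = 3
+ #     padding_rows    = max(0, 31 * 1 + 3 - 32) = 2  -> 1 px top, 1 px bottom
+ # so the padded image is 34x34 and a stride-1 3x3 window produces 32x32 outputs,
+ # matching TensorFlow's "SAME" padding behaviour.
+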
+ def reduce_mean(x, axis=None, keepdim=False):
+     # Reduce over the listed axes; if none are given, reduce over all dimensions.
+     # An explicit `is None` check avoids treating a falsy axis argument as "all dims".
+     if axis is None:
+         axis = range(len(x.shape))
+     for i in sorted(axis, reverse=True):
+         x = torch.mean(x, dim=i, keepdim=keepdim)
+     return x
+
+
+ def reduce_std(x, axis=None, keepdim=False):
+     if axis is None:
+         axis = range(len(x.shape))
+     for i in sorted(axis, reverse=True):
+         x = torch.std(x, dim=i, keepdim=keepdim)
+     return x
+
+
+ def reduce_sum(x, axis=None, keepdim=False):
+     if axis is None:
+         axis = range(len(x.shape))
+     for i in sorted(axis, reverse=True):
+         x = torch.sum(x, dim=i, keepdim=keepdim)
+     return x
+
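+ # Example (illustrative sketch): reduce over the spatial dimensions of an
+ # [N, C, H, W] tensor while keeping them as size-1 dims:
+ #
+ #     x = torch.randn(2, 3, 8, 8)
+ #     m = reduce_mean(x, axis=[2, 3], keepdim=True)   # shape [2, 3, 1, 1]
+ #     s = reduce_sum(x, axis=[2, 3], keepdim=True)    # shape [2, 3, 1, 1]
+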
+ def random_mask(num_batch=1, mask_shape=(256, 256)):
+     list_mask = []
+     for _ in range(num_batch):
+         # rectangle mask
+         image_height = mask_shape[0]
+         image_width = mask_shape[1]
+         max_delta_height = image_height // 8
+         max_delta_width = image_width // 8
+         height = image_height // 4
+         width = image_width // 4
+         max_t = image_height - height
+         max_l = image_width - width
+         t = random.randint(0, max_t)
+         l = random.randint(0, max_l)
+         # bbox = (t, l, height, width)
+         h = random.randint(0, max_delta_height // 2)
+         w = random.randint(0, max_delta_width // 2)
+         mask = torch.zeros((1, 1, image_height, image_width))
+         mask[:, :, t + h:t + height - h, l + w:l + width - w] = 1
+         rect_mask = mask
+
+         # brush mask
+         min_num_vertex = 4
+         max_num_vertex = 12
+         mean_angle = 2 * math.pi / 5
+         angle_range = 2 * math.pi / 15
+         min_width = 12
+         max_width = 40
+         H, W = image_height, image_width
+         average_radius = math.sqrt(H * H + W * W) / 8
+         mask = Image.new('L', (W, H), 0)
+
+         for _ in range(np.random.randint(1, 4)):
+             num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
+             angle_min = mean_angle - np.random.uniform(0, angle_range)
+             angle_max = mean_angle + np.random.uniform(0, angle_range)
+             angles = []
+             vertex = []
+             for i in range(num_vertex):
+                 if i % 2 == 0:
+                     angles.append(2 * math.pi - np.random.uniform(angle_min, angle_max))
+                 else:
+                     angles.append(np.random.uniform(angle_min, angle_max))
+
+             h, w = mask.size
+             vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
+             for i in range(num_vertex):
+                 r = np.clip(
+                     np.random.normal(loc=average_radius, scale=average_radius // 2),
+                     0, 2 * average_radius)
+                 new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
+                 new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
+                 vertex.append((int(new_x), int(new_y)))
+
+             draw = ImageDraw.Draw(mask)
+             width = int(np.random.uniform(min_width, max_width))
+             draw.line(vertex, fill=255, width=width)
+             for v in vertex:
+                 draw.ellipse((v[0] - width // 2,
+                               v[1] - width // 2,
+                               v[0] + width // 2,
+                               v[1] + width // 2),
+                              fill=255)
+
+         # Image.transpose returns a new image, so assign the result back
+         if np.random.normal() > 0:
+             mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
+         if np.random.normal() > 0:
+             mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
+
+         mask = transforms.ToTensor()(mask)
+         mask = mask.reshape((1, 1, H, W))
+         brush_mask = mask
+
+         # combine the rectangle and brush masks
+         mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
+         list_mask.append(mask)
+     mask = torch.cat(list_mask, dim=0)
+     return mask
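+
+ # Example usage (illustrative sketch; 'images' is a hypothetical [4, 3, 256, 256]
+ # batch): sample a batch of free-form masks, where nonzero values mark the region
+ # to be inpainted, and black out that region:
+ #
+ #     masks = random_mask(num_batch=4, mask_shape=(256, 256))   # [4, 1, 256, 256]
+ #     masked_images = images * (1 - masks)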