Spaces:
Runtime error
Runtime error
hfspace gradio demo
Browse files- LICENSE +201 -0
- README.md +2 -13
- app.py +119 -0
- demo_examples/baby.png +0 -0
- demo_examples/bird.png +0 -0
- demo_examples/butterfly.png +0 -0
- demo_examples/head.png +0 -0
- demo_examples/woman.png +0 -0
- ds.py +485 -0
- losses.py +131 -0
- networks_SRGAN.py +347 -0
- networks_T1toT2.py +477 -0
- requirements.txt +334 -0
- src/.gitkeep +0 -0
- src/__pycache__/ds.cpython-310.pyc +0 -0
- src/__pycache__/losses.cpython-310.pyc +0 -0
- src/__pycache__/networks_SRGAN.cpython-310.pyc +0 -0
- src/__pycache__/utils.cpython-310.pyc +0 -0
- src/app.py +115 -0
- src/ds.py +485 -0
- src/flagged/Alpha/0.png +0 -0
- src/flagged/Beta/0.png +0 -0
- src/flagged/Low-res/0.png +0 -0
- src/flagged/Orignal/0.png +0 -0
- src/flagged/Super-res/0.png +0 -0
- src/flagged/Uncertainty/0.png +0 -0
- src/flagged/log.csv +2 -0
- src/losses.py +131 -0
- src/networks_SRGAN.py +347 -0
- src/networks_T1toT2.py +477 -0
- src/utils.py +1273 -0
- utils.py +1304 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,13 +1,2 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
emoji: 🦀
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: green
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 3.0.24
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
license: mit
|
11 |
-
---
|
12 |
-
|
13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
+
# BayesCap
|
2 |
+
Bayesian Identity Cap for Calibrated Uncertainty in Pretrained Neural Networks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from matplotlib import cm
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.models as models
|
10 |
+
from torch.utils.data import Dataset, DataLoader
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from ds import *
|
17 |
+
from losses import *
|
18 |
+
from networks_SRGAN import *
|
19 |
+
from utils import *
|
20 |
+
|
21 |
+
device = 'cuda'
|
22 |
+
|
23 |
+
|
24 |
+
NetG = Generator()
|
25 |
+
model_parameters = filter(lambda p: True, NetG.parameters())
|
26 |
+
params = sum([np.prod(p.size()) for p in model_parameters])
|
27 |
+
print("Number of Parameters:", params)
|
28 |
+
NetC = BayesCap(in_channels=3, out_channels=3)
|
29 |
+
|
30 |
+
ensure_checkpoint_exists('BayesCap_SRGAN.pth')
|
31 |
+
NetG.load_state_dict(torch.load('BayesCap_SRGAN.pth', map_location=device))
|
32 |
+
NetG.to(device)
|
33 |
+
NetG.eval()
|
34 |
+
|
35 |
+
ensure_checkpoint_exists('BayesCap_ckpt.pth')
|
36 |
+
NetC.load_state_dict(torch.load('BayesCap_ckpt.pth', map_location=device))
|
37 |
+
NetC.to(device)
|
38 |
+
NetC.eval()
|
39 |
+
|
40 |
+
def tensor01_to_pil(xt):
|
41 |
+
r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
|
42 |
+
return r
|
43 |
+
|
44 |
+
|
45 |
+
def predict(img):
|
46 |
+
"""
|
47 |
+
img: image
|
48 |
+
"""
|
49 |
+
image_size = (256,256)
|
50 |
+
upscale_factor = 4
|
51 |
+
lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
|
52 |
+
# lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
|
53 |
+
|
54 |
+
img = Image.fromarray(np.array(img))
|
55 |
+
img = lr_transforms(img)
|
56 |
+
lr_tensor = utils.image2tensor(img, range_norm=False, half=False)
|
57 |
+
|
58 |
+
device = 'cuda'
|
59 |
+
dtype = torch.cuda.FloatTensor
|
60 |
+
xLR = lr_tensor.to(device).unsqueeze(0)
|
61 |
+
xLR = xLR.type(dtype)
|
62 |
+
# pass them through the network
|
63 |
+
with torch.no_grad():
|
64 |
+
xSR = NetG(xLR)
|
65 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
66 |
+
|
67 |
+
a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
|
68 |
+
b_map = xSRC_beta[0].to('cpu').data
|
69 |
+
u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
70 |
+
|
71 |
+
|
72 |
+
x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
73 |
+
|
74 |
+
x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
75 |
+
|
76 |
+
#im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))
|
77 |
+
|
78 |
+
a_map = torch.clamp(a_map, min=0, max=0.1)
|
79 |
+
a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
|
80 |
+
x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
81 |
+
|
82 |
+
b_map = torch.clamp(b_map, min=0.45, max=0.75)
|
83 |
+
b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
|
84 |
+
x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
85 |
+
|
86 |
+
u_map = torch.clamp(u_map, min=0, max=0.15)
|
87 |
+
u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
|
88 |
+
x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
89 |
+
|
90 |
+
return x_LR, x_mean, x_alpha, x_beta, x_uncer
|
91 |
+
|
92 |
+
import gradio as gr
|
93 |
+
|
94 |
+
title = "BayesCap"
|
95 |
+
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
|
96 |
+
article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks| <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
|
97 |
+
|
98 |
+
|
99 |
+
gr.Interface(
|
100 |
+
fn=predict,
|
101 |
+
inputs=gr.inputs.Image(type='pil', label="Orignal"),
|
102 |
+
outputs=[
|
103 |
+
gr.outputs.Image(type='pil', label="Low-res"),
|
104 |
+
gr.outputs.Image(type='pil', label="Super-res"),
|
105 |
+
gr.outputs.Image(type='pil', label="Alpha"),
|
106 |
+
gr.outputs.Image(type='pil', label="Beta"),
|
107 |
+
gr.outputs.Image(type='pil', label="Uncertainty")
|
108 |
+
],
|
109 |
+
title=title,
|
110 |
+
description=description,
|
111 |
+
article=article,
|
112 |
+
examples=[
|
113 |
+
["./demo_examples/baby.png"],
|
114 |
+
["./demo_examples/bird.png"],
|
115 |
+
["./demo_examples/butterfly.png"],
|
116 |
+
["./demo_examples/head.png"],
|
117 |
+
["./demo_examples/woman.png"],
|
118 |
+
]
|
119 |
+
).launch(share=True)
|
demo_examples/baby.png
ADDED
demo_examples/bird.png
ADDED
demo_examples/butterfly.png
ADDED
demo_examples/head.png
ADDED
demo_examples/woman.png
ADDED
ds.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import random
|
4 |
+
import copy
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import skimage.transform
|
10 |
+
from collections import Counter
|
11 |
+
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.utils.data as data
|
15 |
+
from torch import Tensor
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
from torchvision import transforms
|
18 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
19 |
+
|
20 |
+
import utils
|
21 |
+
|
22 |
+
class ImgDset(Dataset):
|
23 |
+
"""Customize the data set loading function and prepare low/high resolution image data in advance.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
dataroot (str): Training data set address
|
27 |
+
image_size (int): High resolution image size
|
28 |
+
upscale_factor (int): Image magnification
|
29 |
+
mode (str): Data set loading method, the training data set is for data enhancement,
|
30 |
+
and the verification data set is not for data enhancement
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
|
35 |
+
super(ImgDset, self).__init__()
|
36 |
+
self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
|
37 |
+
|
38 |
+
if mode == "train":
|
39 |
+
self.hr_transforms = transforms.Compose([
|
40 |
+
transforms.RandomCrop(image_size),
|
41 |
+
transforms.RandomRotation(90),
|
42 |
+
transforms.RandomHorizontalFlip(0.5),
|
43 |
+
])
|
44 |
+
else:
|
45 |
+
self.hr_transforms = transforms.Resize(image_size)
|
46 |
+
|
47 |
+
self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
|
48 |
+
|
49 |
+
def __getitem__(self, batch_index: int) -> [Tensor, Tensor]:
|
50 |
+
# Read a batch of image data
|
51 |
+
image = Image.open(self.filenames[batch_index])
|
52 |
+
|
53 |
+
# Transform image
|
54 |
+
hr_image = self.hr_transforms(image)
|
55 |
+
lr_image = self.lr_transforms(hr_image)
|
56 |
+
|
57 |
+
# Convert image data into Tensor stream format (PyTorch).
|
58 |
+
# Note: The range of input and output is between [0, 1]
|
59 |
+
lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
|
60 |
+
hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
|
61 |
+
|
62 |
+
return lr_tensor, hr_tensor
|
63 |
+
|
64 |
+
def __len__(self) -> int:
|
65 |
+
return len(self.filenames)
|
66 |
+
|
67 |
+
|
68 |
+
class PairedImages_w_nameList(Dataset):
|
69 |
+
'''
|
70 |
+
can act as supervised or un-supervised based on flists
|
71 |
+
'''
|
72 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
73 |
+
self.flist1 = flist1
|
74 |
+
self.flist2 = flist2
|
75 |
+
self.transform1 = transform1
|
76 |
+
self.transform2 = transform2
|
77 |
+
self.do_aug = do_aug
|
78 |
+
def __getitem__(self, index):
|
79 |
+
impath1 = self.flist1[index]
|
80 |
+
img1 = Image.open(impath1).convert('RGB')
|
81 |
+
impath2 = self.flist2[index]
|
82 |
+
img2 = Image.open(impath2).convert('RGB')
|
83 |
+
|
84 |
+
img1 = utils.image2tensor(img1, range_norm=False, half=False)
|
85 |
+
img2 = utils.image2tensor(img2, range_norm=False, half=False)
|
86 |
+
|
87 |
+
if self.transform1 is not None:
|
88 |
+
img1 = self.transform1(img1)
|
89 |
+
if self.transform2 is not None:
|
90 |
+
img2 = self.transform2(img2)
|
91 |
+
|
92 |
+
return img1, img2
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.flist1)
|
95 |
+
|
96 |
+
class PairedImages_w_nameList_npy(Dataset):
|
97 |
+
'''
|
98 |
+
can act as supervised or un-supervised based on flists
|
99 |
+
'''
|
100 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
101 |
+
self.flist1 = flist1
|
102 |
+
self.flist2 = flist2
|
103 |
+
self.transform1 = transform1
|
104 |
+
self.transform2 = transform2
|
105 |
+
self.do_aug = do_aug
|
106 |
+
def __getitem__(self, index):
|
107 |
+
impath1 = self.flist1[index]
|
108 |
+
img1 = np.load(impath1)
|
109 |
+
impath2 = self.flist2[index]
|
110 |
+
img2 = np.load(impath2)
|
111 |
+
|
112 |
+
if self.transform1 is not None:
|
113 |
+
img1 = self.transform1(img1)
|
114 |
+
if self.transform2 is not None:
|
115 |
+
img2 = self.transform2(img2)
|
116 |
+
|
117 |
+
return img1, img2
|
118 |
+
def __len__(self):
|
119 |
+
return len(self.flist1)
|
120 |
+
|
121 |
+
# def call_paired():
|
122 |
+
# root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
|
123 |
+
# root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
|
124 |
+
|
125 |
+
# flist1=glob.glob(root1+'/*/*.png')
|
126 |
+
# flist2=glob.glob(root2+'/*/*.png')
|
127 |
+
|
128 |
+
# dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
|
129 |
+
|
130 |
+
#### KITTI depth
|
131 |
+
|
132 |
+
def load_velodyne_points(filename):
|
133 |
+
"""Load 3D point cloud from KITTI file format
|
134 |
+
(adapted from https://github.com/hunse/kitti)
|
135 |
+
"""
|
136 |
+
points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
|
137 |
+
points[:, 3] = 1.0 # homogeneous
|
138 |
+
return points
|
139 |
+
|
140 |
+
|
141 |
+
def read_calib_file(path):
|
142 |
+
"""Read KITTI calibration file
|
143 |
+
(from https://github.com/hunse/kitti)
|
144 |
+
"""
|
145 |
+
float_chars = set("0123456789.e+- ")
|
146 |
+
data = {}
|
147 |
+
with open(path, 'r') as f:
|
148 |
+
for line in f.readlines():
|
149 |
+
key, value = line.split(':', 1)
|
150 |
+
value = value.strip()
|
151 |
+
data[key] = value
|
152 |
+
if float_chars.issuperset(value):
|
153 |
+
# try to cast to float array
|
154 |
+
try:
|
155 |
+
data[key] = np.array(list(map(float, value.split(' '))))
|
156 |
+
except ValueError:
|
157 |
+
# casting error: data[key] already eq. value, so pass
|
158 |
+
pass
|
159 |
+
|
160 |
+
return data
|
161 |
+
|
162 |
+
|
163 |
+
def sub2ind(matrixSize, rowSub, colSub):
|
164 |
+
"""Convert row, col matrix subscripts to linear indices
|
165 |
+
"""
|
166 |
+
m, n = matrixSize
|
167 |
+
return rowSub * (n-1) + colSub - 1
|
168 |
+
|
169 |
+
|
170 |
+
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
|
171 |
+
"""Generate a depth map from velodyne data
|
172 |
+
"""
|
173 |
+
# load calibration files
|
174 |
+
cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
|
175 |
+
velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
|
176 |
+
velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
|
177 |
+
velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
|
178 |
+
|
179 |
+
# get image shape
|
180 |
+
im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
|
181 |
+
|
182 |
+
# compute projection matrix velodyne->image plane
|
183 |
+
R_cam2rect = np.eye(4)
|
184 |
+
R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
|
185 |
+
P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
|
186 |
+
P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
|
187 |
+
|
188 |
+
# load velodyne points and remove all behind image plane (approximation)
|
189 |
+
# each row of the velodyne data is forward, left, up, reflectance
|
190 |
+
velo = load_velodyne_points(velo_filename)
|
191 |
+
velo = velo[velo[:, 0] >= 0, :]
|
192 |
+
|
193 |
+
# project the points to the camera
|
194 |
+
velo_pts_im = np.dot(P_velo2im, velo.T).T
|
195 |
+
velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
|
196 |
+
|
197 |
+
if vel_depth:
|
198 |
+
velo_pts_im[:, 2] = velo[:, 0]
|
199 |
+
|
200 |
+
# check if in bounds
|
201 |
+
# use minus 1 to get the exact same value as KITTI matlab code
|
202 |
+
velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
|
203 |
+
velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
|
204 |
+
val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
|
205 |
+
val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
|
206 |
+
velo_pts_im = velo_pts_im[val_inds, :]
|
207 |
+
|
208 |
+
# project to image
|
209 |
+
depth = np.zeros((im_shape[:2]))
|
210 |
+
depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2]
|
211 |
+
|
212 |
+
# find the duplicate points and choose the closest depth
|
213 |
+
inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
|
214 |
+
dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
|
215 |
+
for dd in dupe_inds:
|
216 |
+
pts = np.where(inds == dd)[0]
|
217 |
+
x_loc = int(velo_pts_im[pts[0], 0])
|
218 |
+
y_loc = int(velo_pts_im[pts[0], 1])
|
219 |
+
depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
|
220 |
+
depth[depth < 0] = 0
|
221 |
+
|
222 |
+
return depth
|
223 |
+
|
224 |
+
def pil_loader(path):
|
225 |
+
# open path as file to avoid ResourceWarning
|
226 |
+
# (https://github.com/python-pillow/Pillow/issues/835)
|
227 |
+
with open(path, 'rb') as f:
|
228 |
+
with Image.open(f) as img:
|
229 |
+
return img.convert('RGB')
|
230 |
+
|
231 |
+
|
232 |
+
class MonoDataset(data.Dataset):
|
233 |
+
"""Superclass for monocular dataloaders
|
234 |
+
|
235 |
+
Args:
|
236 |
+
data_path
|
237 |
+
filenames
|
238 |
+
height
|
239 |
+
width
|
240 |
+
frame_idxs
|
241 |
+
num_scales
|
242 |
+
is_train
|
243 |
+
img_ext
|
244 |
+
"""
|
245 |
+
def __init__(self,
|
246 |
+
data_path,
|
247 |
+
filenames,
|
248 |
+
height,
|
249 |
+
width,
|
250 |
+
frame_idxs,
|
251 |
+
num_scales,
|
252 |
+
is_train=False,
|
253 |
+
img_ext='.jpg'):
|
254 |
+
super(MonoDataset, self).__init__()
|
255 |
+
|
256 |
+
self.data_path = data_path
|
257 |
+
self.filenames = filenames
|
258 |
+
self.height = height
|
259 |
+
self.width = width
|
260 |
+
self.num_scales = num_scales
|
261 |
+
self.interp = Image.ANTIALIAS
|
262 |
+
|
263 |
+
self.frame_idxs = frame_idxs
|
264 |
+
|
265 |
+
self.is_train = is_train
|
266 |
+
self.img_ext = img_ext
|
267 |
+
|
268 |
+
self.loader = pil_loader
|
269 |
+
self.to_tensor = transforms.ToTensor()
|
270 |
+
|
271 |
+
# We need to specify augmentations differently in newer versions of torchvision.
|
272 |
+
# We first try the newer tuple version; if this fails we fall back to scalars
|
273 |
+
try:
|
274 |
+
self.brightness = (0.8, 1.2)
|
275 |
+
self.contrast = (0.8, 1.2)
|
276 |
+
self.saturation = (0.8, 1.2)
|
277 |
+
self.hue = (-0.1, 0.1)
|
278 |
+
transforms.ColorJitter.get_params(
|
279 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
280 |
+
except TypeError:
|
281 |
+
self.brightness = 0.2
|
282 |
+
self.contrast = 0.2
|
283 |
+
self.saturation = 0.2
|
284 |
+
self.hue = 0.1
|
285 |
+
|
286 |
+
self.resize = {}
|
287 |
+
for i in range(self.num_scales):
|
288 |
+
s = 2 ** i
|
289 |
+
self.resize[i] = transforms.Resize((self.height // s, self.width // s),
|
290 |
+
interpolation=self.interp)
|
291 |
+
|
292 |
+
self.load_depth = self.check_depth()
|
293 |
+
|
294 |
+
def preprocess(self, inputs, color_aug):
|
295 |
+
"""Resize colour images to the required scales and augment if required
|
296 |
+
|
297 |
+
We create the color_aug object in advance and apply the same augmentation to all
|
298 |
+
images in this item. This ensures that all images input to the pose network receive the
|
299 |
+
same augmentation.
|
300 |
+
"""
|
301 |
+
for k in list(inputs):
|
302 |
+
frame = inputs[k]
|
303 |
+
if "color" in k:
|
304 |
+
n, im, i = k
|
305 |
+
for i in range(self.num_scales):
|
306 |
+
inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])
|
307 |
+
|
308 |
+
for k in list(inputs):
|
309 |
+
f = inputs[k]
|
310 |
+
if "color" in k:
|
311 |
+
n, im, i = k
|
312 |
+
inputs[(n, im, i)] = self.to_tensor(f)
|
313 |
+
inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))
|
314 |
+
|
315 |
+
def __len__(self):
|
316 |
+
return len(self.filenames)
|
317 |
+
|
318 |
+
def __getitem__(self, index):
|
319 |
+
"""Returns a single training item from the dataset as a dictionary.
|
320 |
+
|
321 |
+
Values correspond to torch tensors.
|
322 |
+
Keys in the dictionary are either strings or tuples:
|
323 |
+
|
324 |
+
("color", <frame_id>, <scale>) for raw colour images,
|
325 |
+
("color_aug", <frame_id>, <scale>) for augmented colour images,
|
326 |
+
("K", scale) or ("inv_K", scale) for camera intrinsics,
|
327 |
+
"stereo_T" for camera extrinsics, and
|
328 |
+
"depth_gt" for ground truth depth maps.
|
329 |
+
|
330 |
+
<frame_id> is either:
|
331 |
+
an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
|
332 |
+
or
|
333 |
+
"s" for the opposite image in the stereo pair.
|
334 |
+
|
335 |
+
<scale> is an integer representing the scale of the image relative to the fullsize image:
|
336 |
+
-1 images at native resolution as loaded from disk
|
337 |
+
0 images resized to (self.width, self.height )
|
338 |
+
1 images resized to (self.width // 2, self.height // 2)
|
339 |
+
2 images resized to (self.width // 4, self.height // 4)
|
340 |
+
3 images resized to (self.width // 8, self.height // 8)
|
341 |
+
"""
|
342 |
+
inputs = {}
|
343 |
+
|
344 |
+
do_color_aug = self.is_train and random.random() > 0.5
|
345 |
+
do_flip = self.is_train and random.random() > 0.5
|
346 |
+
|
347 |
+
line = self.filenames[index].split()
|
348 |
+
folder = line[0]
|
349 |
+
|
350 |
+
if len(line) == 3:
|
351 |
+
frame_index = int(line[1])
|
352 |
+
else:
|
353 |
+
frame_index = 0
|
354 |
+
|
355 |
+
if len(line) == 3:
|
356 |
+
side = line[2]
|
357 |
+
else:
|
358 |
+
side = None
|
359 |
+
|
360 |
+
for i in self.frame_idxs:
|
361 |
+
if i == "s":
|
362 |
+
other_side = {"r": "l", "l": "r"}[side]
|
363 |
+
inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
|
364 |
+
else:
|
365 |
+
inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)
|
366 |
+
|
367 |
+
# adjusting intrinsics to match each scale in the pyramid
|
368 |
+
for scale in range(self.num_scales):
|
369 |
+
K = self.K.copy()
|
370 |
+
|
371 |
+
K[0, :] *= self.width // (2 ** scale)
|
372 |
+
K[1, :] *= self.height // (2 ** scale)
|
373 |
+
|
374 |
+
inv_K = np.linalg.pinv(K)
|
375 |
+
|
376 |
+
inputs[("K", scale)] = torch.from_numpy(K)
|
377 |
+
inputs[("inv_K", scale)] = torch.from_numpy(inv_K)
|
378 |
+
|
379 |
+
if do_color_aug:
|
380 |
+
color_aug = transforms.ColorJitter.get_params(
|
381 |
+
self.brightness, self.contrast, self.saturation, self.hue)
|
382 |
+
else:
|
383 |
+
color_aug = (lambda x: x)
|
384 |
+
|
385 |
+
self.preprocess(inputs, color_aug)
|
386 |
+
|
387 |
+
for i in self.frame_idxs:
|
388 |
+
del inputs[("color", i, -1)]
|
389 |
+
del inputs[("color_aug", i, -1)]
|
390 |
+
|
391 |
+
if self.load_depth:
|
392 |
+
depth_gt = self.get_depth(folder, frame_index, side, do_flip)
|
393 |
+
inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
|
394 |
+
inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))
|
395 |
+
|
396 |
+
if "s" in self.frame_idxs:
|
397 |
+
stereo_T = np.eye(4, dtype=np.float32)
|
398 |
+
baseline_sign = -1 if do_flip else 1
|
399 |
+
side_sign = -1 if side == "l" else 1
|
400 |
+
stereo_T[0, 3] = side_sign * baseline_sign * 0.1
|
401 |
+
|
402 |
+
inputs["stereo_T"] = torch.from_numpy(stereo_T)
|
403 |
+
|
404 |
+
return inputs
|
405 |
+
|
406 |
+
def get_color(self, folder, frame_index, side, do_flip):
|
407 |
+
raise NotImplementedError
|
408 |
+
|
409 |
+
def check_depth(self):
|
410 |
+
raise NotImplementedError
|
411 |
+
|
412 |
+
def get_depth(self, folder, frame_index, side, do_flip):
|
413 |
+
raise NotImplementedError
|
414 |
+
|
415 |
+
class KITTIDataset(MonoDataset):
|
416 |
+
"""Superclass for different types of KITTI dataset loaders
|
417 |
+
"""
|
418 |
+
def __init__(self, *args, **kwargs):
|
419 |
+
super(KITTIDataset, self).__init__(*args, **kwargs)
|
420 |
+
|
421 |
+
# NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
|
422 |
+
# To normalize you need to scale the first row by 1 / image_width and the second row
|
423 |
+
# by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
|
424 |
+
# If your principal point is far from the center you might need to disable the horizontal
|
425 |
+
# flip augmentation.
|
426 |
+
self.K = np.array([[0.58, 0, 0.5, 0],
|
427 |
+
[0, 1.92, 0.5, 0],
|
428 |
+
[0, 0, 1, 0],
|
429 |
+
[0, 0, 0, 1]], dtype=np.float32)
|
430 |
+
|
431 |
+
self.full_res_shape = (1242, 375)
|
432 |
+
self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}
|
433 |
+
|
434 |
+
def check_depth(self):
|
435 |
+
line = self.filenames[0].split()
|
436 |
+
scene_name = line[0]
|
437 |
+
frame_index = int(line[1])
|
438 |
+
|
439 |
+
velo_filename = os.path.join(
|
440 |
+
self.data_path,
|
441 |
+
scene_name,
|
442 |
+
"velodyne_points/data/{:010d}.bin".format(int(frame_index)))
|
443 |
+
|
444 |
+
return os.path.isfile(velo_filename)
|
445 |
+
|
446 |
+
def get_color(self, folder, frame_index, side, do_flip):
|
447 |
+
color = self.loader(self.get_image_path(folder, frame_index, side))
|
448 |
+
|
449 |
+
if do_flip:
|
450 |
+
color = color.transpose(Image.FLIP_LEFT_RIGHT)
|
451 |
+
|
452 |
+
return color
|
453 |
+
|
454 |
+
|
455 |
+
class KITTIDepthDataset(KITTIDataset):
|
456 |
+
"""KITTI dataset which uses the updated ground truth depth maps
|
457 |
+
"""
|
458 |
+
def __init__(self, *args, **kwargs):
|
459 |
+
super(KITTIDepthDataset, self).__init__(*args, **kwargs)
|
460 |
+
|
461 |
+
def get_image_path(self, folder, frame_index, side):
|
462 |
+
f_str = "{:010d}{}".format(frame_index, self.img_ext)
|
463 |
+
image_path = os.path.join(
|
464 |
+
self.data_path,
|
465 |
+
folder,
|
466 |
+
"image_0{}/data".format(self.side_map[side]),
|
467 |
+
f_str)
|
468 |
+
return image_path
|
469 |
+
|
470 |
+
def get_depth(self, folder, frame_index, side, do_flip):
|
471 |
+
f_str = "{:010d}.png".format(frame_index)
|
472 |
+
depth_path = os.path.join(
|
473 |
+
self.data_path,
|
474 |
+
folder,
|
475 |
+
"proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
|
476 |
+
f_str)
|
477 |
+
|
478 |
+
depth_gt = Image.open(depth_path)
|
479 |
+
depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
|
480 |
+
depth_gt = np.array(depth_gt).astype(np.float32) / 256
|
481 |
+
|
482 |
+
if do_flip:
|
483 |
+
depth_gt = np.fliplr(depth_gt)
|
484 |
+
|
485 |
+
return depth_gt
|
losses.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as models
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
class ContentLoss(nn.Module):
|
8 |
+
"""Constructs a content loss function based on the VGG19 network.
|
9 |
+
Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.
|
10 |
+
|
11 |
+
Paper reference list:
|
12 |
+
-`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
|
13 |
+
-`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
|
14 |
+
-`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.
|
15 |
+
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self) -> None:
|
19 |
+
super(ContentLoss, self).__init__()
|
20 |
+
# Load the VGG19 model trained on the ImageNet dataset.
|
21 |
+
vgg19 = models.vgg19(pretrained=True).eval()
|
22 |
+
# Extract the thirty-sixth layer output in the VGG19 model as the content loss.
|
23 |
+
self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
|
24 |
+
# Freeze model parameters.
|
25 |
+
for parameters in self.feature_extractor.parameters():
|
26 |
+
parameters.requires_grad = False
|
27 |
+
|
28 |
+
# The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
|
29 |
+
self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
30 |
+
self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
31 |
+
|
32 |
+
def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
|
33 |
+
# Standardized operations
|
34 |
+
sr = sr.sub(self.mean).div(self.std)
|
35 |
+
hr = hr.sub(self.mean).div(self.std)
|
36 |
+
|
37 |
+
# Find the feature map difference between the two images
|
38 |
+
loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))
|
39 |
+
|
40 |
+
return loss
|
41 |
+
|
42 |
+
|
43 |
+
class GenGaussLoss(nn.Module):
|
44 |
+
def __init__(
|
45 |
+
self, reduction='mean',
|
46 |
+
alpha_eps = 1e-4, beta_eps=1e-4,
|
47 |
+
resi_min = 1e-4, resi_max=1e3
|
48 |
+
) -> None:
|
49 |
+
super(GenGaussLoss, self).__init__()
|
50 |
+
self.reduction = reduction
|
51 |
+
self.alpha_eps = alpha_eps
|
52 |
+
self.beta_eps = beta_eps
|
53 |
+
self.resi_min = resi_min
|
54 |
+
self.resi_max = resi_max
|
55 |
+
|
56 |
+
def forward(
|
57 |
+
self,
|
58 |
+
mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
|
59 |
+
):
|
60 |
+
one_over_alpha1 = one_over_alpha + self.alpha_eps
|
61 |
+
beta1 = beta + self.beta_eps
|
62 |
+
|
63 |
+
resi = torch.abs(mean - target)
|
64 |
+
# resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
|
65 |
+
resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
|
66 |
+
## check if resi has nans
|
67 |
+
if torch.sum(resi != resi) > 0:
|
68 |
+
print('resi has nans!!')
|
69 |
+
return None
|
70 |
+
|
71 |
+
log_one_over_alpha = torch.log(one_over_alpha1)
|
72 |
+
log_beta = torch.log(beta1)
|
73 |
+
lgamma_beta = torch.lgamma(torch.pow(beta1, -1))
|
74 |
+
|
75 |
+
if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
|
76 |
+
print('log_one_over_alpha has nan')
|
77 |
+
if torch.sum(lgamma_beta != lgamma_beta) > 0:
|
78 |
+
print('lgamma_beta has nan')
|
79 |
+
if torch.sum(log_beta != log_beta) > 0:
|
80 |
+
print('log_beta has nan')
|
81 |
+
|
82 |
+
l = resi - log_one_over_alpha + lgamma_beta - log_beta
|
83 |
+
|
84 |
+
if self.reduction == 'mean':
|
85 |
+
return l.mean()
|
86 |
+
elif self.reduction == 'sum':
|
87 |
+
return l.sum()
|
88 |
+
else:
|
89 |
+
print('Reduction not supported')
|
90 |
+
return None
|
91 |
+
|
92 |
+
class TempCombLoss(nn.Module):
|
93 |
+
def __init__(
|
94 |
+
self, reduction='mean',
|
95 |
+
alpha_eps = 1e-4, beta_eps=1e-4,
|
96 |
+
resi_min = 1e-4, resi_max=1e3
|
97 |
+
) -> None:
|
98 |
+
super(TempCombLoss, self).__init__()
|
99 |
+
self.reduction = reduction
|
100 |
+
self.alpha_eps = alpha_eps
|
101 |
+
self.beta_eps = beta_eps
|
102 |
+
self.resi_min = resi_min
|
103 |
+
self.resi_max = resi_max
|
104 |
+
|
105 |
+
self.L_GenGauss = GenGaussLoss(
|
106 |
+
reduction=self.reduction,
|
107 |
+
alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
|
108 |
+
resi_min=self.resi_min, resi_max=self.resi_max
|
109 |
+
)
|
110 |
+
self.L_l1 = nn.L1Loss(reduction=self.reduction)
|
111 |
+
|
112 |
+
def forward(
|
113 |
+
self,
|
114 |
+
mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
|
115 |
+
T1: float, T2: float
|
116 |
+
):
|
117 |
+
l1 = self.L_l1(mean, target)
|
118 |
+
l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
|
119 |
+
l = T1*l1 + T2*l2
|
120 |
+
|
121 |
+
return l
|
122 |
+
|
123 |
+
|
124 |
+
# x1 = torch.randn(4,3,32,32)
|
125 |
+
# x2 = torch.rand(4,3,32,32)
|
126 |
+
# x3 = torch.rand(4,3,32,32)
|
127 |
+
# x4 = torch.randn(4,3,32,32)
|
128 |
+
|
129 |
+
# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
130 |
+
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
|
131 |
+
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
|
networks_SRGAN.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision.models as models
|
5 |
+
from torch import Tensor
|
6 |
+
|
7 |
+
# __all__ = [
|
8 |
+
# "ResidualConvBlock",
|
9 |
+
# "Discriminator", "Generator",
|
10 |
+
# ]
|
11 |
+
|
12 |
+
|
13 |
+
class ResidualConvBlock(nn.Module):
|
14 |
+
"""Implements residual conv function.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
channels (int): Number of channels in the input image.
|
18 |
+
"""
|
19 |
+
|
20 |
+
def __init__(self, channels: int) -> None:
|
21 |
+
super(ResidualConvBlock, self).__init__()
|
22 |
+
self.rcb = nn.Sequential(
|
23 |
+
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
|
24 |
+
nn.BatchNorm2d(channels),
|
25 |
+
nn.PReLU(),
|
26 |
+
nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
|
27 |
+
nn.BatchNorm2d(channels),
|
28 |
+
)
|
29 |
+
|
30 |
+
def forward(self, x: Tensor) -> Tensor:
|
31 |
+
identity = x
|
32 |
+
|
33 |
+
out = self.rcb(x)
|
34 |
+
out = torch.add(out, identity)
|
35 |
+
|
36 |
+
return out
|
37 |
+
|
38 |
+
|
39 |
+
class Discriminator(nn.Module):
|
40 |
+
def __init__(self) -> None:
|
41 |
+
super(Discriminator, self).__init__()
|
42 |
+
self.features = nn.Sequential(
|
43 |
+
# input size. (3) x 96 x 96
|
44 |
+
nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
|
45 |
+
nn.LeakyReLU(0.2, True),
|
46 |
+
# state size. (64) x 48 x 48
|
47 |
+
nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
|
48 |
+
nn.BatchNorm2d(64),
|
49 |
+
nn.LeakyReLU(0.2, True),
|
50 |
+
nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
|
51 |
+
nn.BatchNorm2d(128),
|
52 |
+
nn.LeakyReLU(0.2, True),
|
53 |
+
# state size. (128) x 24 x 24
|
54 |
+
nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
|
55 |
+
nn.BatchNorm2d(128),
|
56 |
+
nn.LeakyReLU(0.2, True),
|
57 |
+
nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
|
58 |
+
nn.BatchNorm2d(256),
|
59 |
+
nn.LeakyReLU(0.2, True),
|
60 |
+
# state size. (256) x 12 x 12
|
61 |
+
nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
|
62 |
+
nn.BatchNorm2d(256),
|
63 |
+
nn.LeakyReLU(0.2, True),
|
64 |
+
nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
|
65 |
+
nn.BatchNorm2d(512),
|
66 |
+
nn.LeakyReLU(0.2, True),
|
67 |
+
# state size. (512) x 6 x 6
|
68 |
+
nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
|
69 |
+
nn.BatchNorm2d(512),
|
70 |
+
nn.LeakyReLU(0.2, True),
|
71 |
+
)
|
72 |
+
|
73 |
+
self.classifier = nn.Sequential(
|
74 |
+
nn.Linear(512 * 6 * 6, 1024),
|
75 |
+
nn.LeakyReLU(0.2, True),
|
76 |
+
nn.Linear(1024, 1),
|
77 |
+
)
|
78 |
+
|
79 |
+
def forward(self, x: Tensor) -> Tensor:
|
80 |
+
out = self.features(x)
|
81 |
+
out = torch.flatten(out, 1)
|
82 |
+
out = self.classifier(out)
|
83 |
+
|
84 |
+
return out
|
85 |
+
|
86 |
+
|
87 |
+
class Generator(nn.Module):
|
88 |
+
def __init__(self) -> None:
|
89 |
+
super(Generator, self).__init__()
|
90 |
+
# First conv layer.
|
91 |
+
self.conv_block1 = nn.Sequential(
|
92 |
+
nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
|
93 |
+
nn.PReLU(),
|
94 |
+
)
|
95 |
+
|
96 |
+
# Features trunk blocks.
|
97 |
+
trunk = []
|
98 |
+
for _ in range(16):
|
99 |
+
trunk.append(ResidualConvBlock(64))
|
100 |
+
self.trunk = nn.Sequential(*trunk)
|
101 |
+
|
102 |
+
# Second conv layer.
|
103 |
+
self.conv_block2 = nn.Sequential(
|
104 |
+
nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
|
105 |
+
nn.BatchNorm2d(64),
|
106 |
+
)
|
107 |
+
|
108 |
+
# Upscale conv block.
|
109 |
+
self.upsampling = nn.Sequential(
|
110 |
+
nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
|
111 |
+
nn.PixelShuffle(2),
|
112 |
+
nn.PReLU(),
|
113 |
+
nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
|
114 |
+
nn.PixelShuffle(2),
|
115 |
+
nn.PReLU(),
|
116 |
+
)
|
117 |
+
|
118 |
+
# Output layer.
|
119 |
+
self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))
|
120 |
+
|
121 |
+
# Initialize neural network weights.
|
122 |
+
self._initialize_weights()
|
123 |
+
|
124 |
+
def forward(self, x: Tensor, dop=None) -> Tensor:
|
125 |
+
if not dop:
|
126 |
+
return self._forward_impl(x)
|
127 |
+
else:
|
128 |
+
return self._forward_w_dop_impl(x, dop)
|
129 |
+
|
130 |
+
# Support torch.script function.
|
131 |
+
def _forward_impl(self, x: Tensor) -> Tensor:
|
132 |
+
out1 = self.conv_block1(x)
|
133 |
+
out = self.trunk(out1)
|
134 |
+
out2 = self.conv_block2(out)
|
135 |
+
out = torch.add(out1, out2)
|
136 |
+
out = self.upsampling(out)
|
137 |
+
out = self.conv_block3(out)
|
138 |
+
|
139 |
+
return out
|
140 |
+
|
141 |
+
def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
|
142 |
+
out1 = self.conv_block1(x)
|
143 |
+
out = self.trunk(out1)
|
144 |
+
out2 = F.dropout2d(self.conv_block2(out), p=dop)
|
145 |
+
out = torch.add(out1, out2)
|
146 |
+
out = self.upsampling(out)
|
147 |
+
out = self.conv_block3(out)
|
148 |
+
|
149 |
+
return out
|
150 |
+
|
151 |
+
def _initialize_weights(self) -> None:
|
152 |
+
for module in self.modules():
|
153 |
+
if isinstance(module, nn.Conv2d):
|
154 |
+
nn.init.kaiming_normal_(module.weight)
|
155 |
+
if module.bias is not None:
|
156 |
+
nn.init.constant_(module.bias, 0)
|
157 |
+
elif isinstance(module, nn.BatchNorm2d):
|
158 |
+
nn.init.constant_(module.weight, 1)
|
159 |
+
|
160 |
+
|
161 |
+
#### BayesCap
|
162 |
+
class BayesCap(nn.Module):
|
163 |
+
def __init__(self, in_channels=3, out_channels=3) -> None:
|
164 |
+
super(BayesCap, self).__init__()
|
165 |
+
# First conv layer.
|
166 |
+
self.conv_block1 = nn.Sequential(
|
167 |
+
nn.Conv2d(
|
168 |
+
in_channels, 64,
|
169 |
+
kernel_size=9, stride=1, padding=4
|
170 |
+
),
|
171 |
+
nn.PReLU(),
|
172 |
+
)
|
173 |
+
|
174 |
+
# Features trunk blocks.
|
175 |
+
trunk = []
|
176 |
+
for _ in range(16):
|
177 |
+
trunk.append(ResidualConvBlock(64))
|
178 |
+
self.trunk = nn.Sequential(*trunk)
|
179 |
+
|
180 |
+
# Second conv layer.
|
181 |
+
self.conv_block2 = nn.Sequential(
|
182 |
+
nn.Conv2d(
|
183 |
+
64, 64,
|
184 |
+
kernel_size=3, stride=1, padding=1, bias=False
|
185 |
+
),
|
186 |
+
nn.BatchNorm2d(64),
|
187 |
+
)
|
188 |
+
|
189 |
+
# Output layer.
|
190 |
+
self.conv_block3_mu = nn.Conv2d(
|
191 |
+
64, out_channels=out_channels,
|
192 |
+
kernel_size=9, stride=1, padding=4
|
193 |
+
)
|
194 |
+
self.conv_block3_alpha = nn.Sequential(
|
195 |
+
nn.Conv2d(
|
196 |
+
64, 64,
|
197 |
+
kernel_size=9, stride=1, padding=4
|
198 |
+
),
|
199 |
+
nn.PReLU(),
|
200 |
+
nn.Conv2d(
|
201 |
+
64, 64,
|
202 |
+
kernel_size=9, stride=1, padding=4
|
203 |
+
),
|
204 |
+
nn.PReLU(),
|
205 |
+
nn.Conv2d(
|
206 |
+
64, 1,
|
207 |
+
kernel_size=9, stride=1, padding=4
|
208 |
+
),
|
209 |
+
nn.ReLU(),
|
210 |
+
)
|
211 |
+
self.conv_block3_beta = nn.Sequential(
|
212 |
+
nn.Conv2d(
|
213 |
+
64, 64,
|
214 |
+
kernel_size=9, stride=1, padding=4
|
215 |
+
),
|
216 |
+
nn.PReLU(),
|
217 |
+
nn.Conv2d(
|
218 |
+
64, 64,
|
219 |
+
kernel_size=9, stride=1, padding=4
|
220 |
+
),
|
221 |
+
nn.PReLU(),
|
222 |
+
nn.Conv2d(
|
223 |
+
64, 1,
|
224 |
+
kernel_size=9, stride=1, padding=4
|
225 |
+
),
|
226 |
+
nn.ReLU(),
|
227 |
+
)
|
228 |
+
|
229 |
+
# Initialize neural network weights.
|
230 |
+
self._initialize_weights()
|
231 |
+
|
232 |
+
def forward(self, x: Tensor) -> Tensor:
|
233 |
+
return self._forward_impl(x)
|
234 |
+
|
235 |
+
# Support torch.script function.
|
236 |
+
def _forward_impl(self, x: Tensor) -> Tensor:
|
237 |
+
out1 = self.conv_block1(x)
|
238 |
+
out = self.trunk(out1)
|
239 |
+
out2 = self.conv_block2(out)
|
240 |
+
out = out1 + out2
|
241 |
+
out_mu = self.conv_block3_mu(out)
|
242 |
+
out_alpha = self.conv_block3_alpha(out)
|
243 |
+
out_beta = self.conv_block3_beta(out)
|
244 |
+
return out_mu, out_alpha, out_beta
|
245 |
+
|
246 |
+
def _initialize_weights(self) -> None:
|
247 |
+
for module in self.modules():
|
248 |
+
if isinstance(module, nn.Conv2d):
|
249 |
+
nn.init.kaiming_normal_(module.weight)
|
250 |
+
if module.bias is not None:
|
251 |
+
nn.init.constant_(module.bias, 0)
|
252 |
+
elif isinstance(module, nn.BatchNorm2d):
|
253 |
+
nn.init.constant_(module.weight, 1)
|
254 |
+
|
255 |
+
|
256 |
+
class BayesCap_noID(nn.Module):
|
257 |
+
def __init__(self, in_channels=3, out_channels=3) -> None:
|
258 |
+
super(BayesCap_noID, self).__init__()
|
259 |
+
# First conv layer.
|
260 |
+
self.conv_block1 = nn.Sequential(
|
261 |
+
nn.Conv2d(
|
262 |
+
in_channels, 64,
|
263 |
+
kernel_size=9, stride=1, padding=4
|
264 |
+
),
|
265 |
+
nn.PReLU(),
|
266 |
+
)
|
267 |
+
|
268 |
+
# Features trunk blocks.
|
269 |
+
trunk = []
|
270 |
+
for _ in range(16):
|
271 |
+
trunk.append(ResidualConvBlock(64))
|
272 |
+
self.trunk = nn.Sequential(*trunk)
|
273 |
+
|
274 |
+
# Second conv layer.
|
275 |
+
self.conv_block2 = nn.Sequential(
|
276 |
+
nn.Conv2d(
|
277 |
+
64, 64,
|
278 |
+
kernel_size=3, stride=1, padding=1, bias=False
|
279 |
+
),
|
280 |
+
nn.BatchNorm2d(64),
|
281 |
+
)
|
282 |
+
|
283 |
+
# Output layer.
|
284 |
+
# self.conv_block3_mu = nn.Conv2d(
|
285 |
+
# 64, out_channels=out_channels,
|
286 |
+
# kernel_size=9, stride=1, padding=4
|
287 |
+
# )
|
288 |
+
self.conv_block3_alpha = nn.Sequential(
|
289 |
+
nn.Conv2d(
|
290 |
+
64, 64,
|
291 |
+
kernel_size=9, stride=1, padding=4
|
292 |
+
),
|
293 |
+
nn.PReLU(),
|
294 |
+
nn.Conv2d(
|
295 |
+
64, 64,
|
296 |
+
kernel_size=9, stride=1, padding=4
|
297 |
+
),
|
298 |
+
nn.PReLU(),
|
299 |
+
nn.Conv2d(
|
300 |
+
64, 1,
|
301 |
+
kernel_size=9, stride=1, padding=4
|
302 |
+
),
|
303 |
+
nn.ReLU(),
|
304 |
+
)
|
305 |
+
self.conv_block3_beta = nn.Sequential(
|
306 |
+
nn.Conv2d(
|
307 |
+
64, 64,
|
308 |
+
kernel_size=9, stride=1, padding=4
|
309 |
+
),
|
310 |
+
nn.PReLU(),
|
311 |
+
nn.Conv2d(
|
312 |
+
64, 64,
|
313 |
+
kernel_size=9, stride=1, padding=4
|
314 |
+
),
|
315 |
+
nn.PReLU(),
|
316 |
+
nn.Conv2d(
|
317 |
+
64, 1,
|
318 |
+
kernel_size=9, stride=1, padding=4
|
319 |
+
),
|
320 |
+
nn.ReLU(),
|
321 |
+
)
|
322 |
+
|
323 |
+
# Initialize neural network weights.
|
324 |
+
self._initialize_weights()
|
325 |
+
|
326 |
+
def forward(self, x: Tensor) -> Tensor:
|
327 |
+
return self._forward_impl(x)
|
328 |
+
|
329 |
+
# Support torch.script function.
|
330 |
+
def _forward_impl(self, x: Tensor) -> Tensor:
|
331 |
+
out1 = self.conv_block1(x)
|
332 |
+
out = self.trunk(out1)
|
333 |
+
out2 = self.conv_block2(out)
|
334 |
+
out = out1 + out2
|
335 |
+
# out_mu = self.conv_block3_mu(out)
|
336 |
+
out_alpha = self.conv_block3_alpha(out)
|
337 |
+
out_beta = self.conv_block3_beta(out)
|
338 |
+
return out_alpha, out_beta
|
339 |
+
|
340 |
+
def _initialize_weights(self) -> None:
|
341 |
+
for module in self.modules():
|
342 |
+
if isinstance(module, nn.Conv2d):
|
343 |
+
nn.init.kaiming_normal_(module.weight)
|
344 |
+
if module.bias is not None:
|
345 |
+
nn.init.constant_(module.bias, 0)
|
346 |
+
elif isinstance(module, nn.BatchNorm2d):
|
347 |
+
nn.init.constant_(module.weight, 1)
|
networks_T1toT2.py
ADDED
@@ -0,0 +1,477 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import functools
|
5 |
+
|
6 |
+
### components
|
7 |
+
class ResConv(nn.Module):
|
8 |
+
"""
|
9 |
+
Residual convolutional block, where
|
10 |
+
convolutional block consists: (convolution => [BN] => ReLU) * 3
|
11 |
+
residual connection adds the input to the output
|
12 |
+
"""
|
13 |
+
def __init__(self, in_channels, out_channels, mid_channels=None):
|
14 |
+
super().__init__()
|
15 |
+
if not mid_channels:
|
16 |
+
mid_channels = out_channels
|
17 |
+
self.double_conv = nn.Sequential(
|
18 |
+
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
|
19 |
+
nn.BatchNorm2d(mid_channels),
|
20 |
+
nn.ReLU(inplace=True),
|
21 |
+
nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
|
22 |
+
nn.BatchNorm2d(mid_channels),
|
23 |
+
nn.ReLU(inplace=True),
|
24 |
+
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
|
25 |
+
nn.BatchNorm2d(out_channels),
|
26 |
+
nn.ReLU(inplace=True)
|
27 |
+
)
|
28 |
+
self.double_conv1 = nn.Sequential(
|
29 |
+
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
|
30 |
+
nn.BatchNorm2d(out_channels),
|
31 |
+
nn.ReLU(inplace=True),
|
32 |
+
)
|
33 |
+
def forward(self, x):
|
34 |
+
x_in = self.double_conv1(x)
|
35 |
+
x1 = self.double_conv(x)
|
36 |
+
return self.double_conv(x) + x_in
|
37 |
+
|
38 |
+
class Down(nn.Module):
|
39 |
+
"""Downscaling with maxpool then Resconv"""
|
40 |
+
def __init__(self, in_channels, out_channels):
|
41 |
+
super().__init__()
|
42 |
+
self.maxpool_conv = nn.Sequential(
|
43 |
+
nn.MaxPool2d(2),
|
44 |
+
ResConv(in_channels, out_channels)
|
45 |
+
)
|
46 |
+
def forward(self, x):
|
47 |
+
return self.maxpool_conv(x)
|
48 |
+
|
49 |
+
class Up(nn.Module):
|
50 |
+
"""Upscaling then double conv"""
|
51 |
+
def __init__(self, in_channels, out_channels, bilinear=True):
|
52 |
+
super().__init__()
|
53 |
+
# if bilinear, use the normal convolutions to reduce the number of channels
|
54 |
+
if bilinear:
|
55 |
+
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
56 |
+
self.conv = ResConv(in_channels, out_channels, in_channels // 2)
|
57 |
+
else:
|
58 |
+
self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
|
59 |
+
self.conv = ResConv(in_channels, out_channels)
|
60 |
+
def forward(self, x1, x2):
|
61 |
+
x1 = self.up(x1)
|
62 |
+
# input is CHW
|
63 |
+
diffY = x2.size()[2] - x1.size()[2]
|
64 |
+
diffX = x2.size()[3] - x1.size()[3]
|
65 |
+
x1 = F.pad(
|
66 |
+
x1,
|
67 |
+
[
|
68 |
+
diffX // 2, diffX - diffX // 2,
|
69 |
+
diffY // 2, diffY - diffY // 2
|
70 |
+
]
|
71 |
+
)
|
72 |
+
# if you have padding issues, see
|
73 |
+
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
|
74 |
+
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
|
75 |
+
x = torch.cat([x2, x1], dim=1)
|
76 |
+
return self.conv(x)
|
77 |
+
|
78 |
+
class OutConv(nn.Module):
|
79 |
+
def __init__(self, in_channels, out_channels):
|
80 |
+
super(OutConv, self).__init__()
|
81 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
|
82 |
+
def forward(self, x):
|
83 |
+
# return F.relu(self.conv(x))
|
84 |
+
return self.conv(x)
|
85 |
+
|
86 |
+
##### The composite networks
|
87 |
+
class UNet(nn.Module):
|
88 |
+
def __init__(self, n_channels, out_channels, bilinear=True):
|
89 |
+
super(UNet, self).__init__()
|
90 |
+
self.n_channels = n_channels
|
91 |
+
self.out_channels = out_channels
|
92 |
+
self.bilinear = bilinear
|
93 |
+
####
|
94 |
+
self.inc = ResConv(n_channels, 64)
|
95 |
+
self.down1 = Down(64, 128)
|
96 |
+
self.down2 = Down(128, 256)
|
97 |
+
self.down3 = Down(256, 512)
|
98 |
+
factor = 2 if bilinear else 1
|
99 |
+
self.down4 = Down(512, 1024 // factor)
|
100 |
+
self.up1 = Up(1024, 512 // factor, bilinear)
|
101 |
+
self.up2 = Up(512, 256 // factor, bilinear)
|
102 |
+
self.up3 = Up(256, 128 // factor, bilinear)
|
103 |
+
self.up4 = Up(128, 64, bilinear)
|
104 |
+
self.outc = OutConv(64, out_channels)
|
105 |
+
def forward(self, x):
|
106 |
+
x1 = self.inc(x)
|
107 |
+
x2 = self.down1(x1)
|
108 |
+
x3 = self.down2(x2)
|
109 |
+
x4 = self.down3(x3)
|
110 |
+
x5 = self.down4(x4)
|
111 |
+
x = self.up1(x5, x4)
|
112 |
+
x = self.up2(x, x3)
|
113 |
+
x = self.up3(x, x2)
|
114 |
+
x = self.up4(x, x1)
|
115 |
+
y = self.outc(x)
|
116 |
+
return y
|
117 |
+
|
118 |
+
class CasUNet(nn.Module):
|
119 |
+
def __init__(self, n_unet, io_channels, bilinear=True):
|
120 |
+
super(CasUNet, self).__init__()
|
121 |
+
self.n_unet = n_unet
|
122 |
+
self.io_channels = io_channels
|
123 |
+
self.bilinear = bilinear
|
124 |
+
####
|
125 |
+
self.unet_list = nn.ModuleList()
|
126 |
+
for i in range(self.n_unet):
|
127 |
+
self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
|
128 |
+
def forward(self, x, dop=None):
|
129 |
+
y = x
|
130 |
+
for i in range(self.n_unet):
|
131 |
+
if i==0:
|
132 |
+
if dop is not None:
|
133 |
+
y = F.dropout2d(self.unet_list[i](y), p=dop)
|
134 |
+
else:
|
135 |
+
y = self.unet_list[i](y)
|
136 |
+
else:
|
137 |
+
y = self.unet_list[i](y+x)
|
138 |
+
return y
|
139 |
+
|
140 |
+
class CasUNet_2head(nn.Module):
|
141 |
+
def __init__(self, n_unet, io_channels, bilinear=True):
|
142 |
+
super(CasUNet_2head, self).__init__()
|
143 |
+
self.n_unet = n_unet
|
144 |
+
self.io_channels = io_channels
|
145 |
+
self.bilinear = bilinear
|
146 |
+
####
|
147 |
+
self.unet_list = nn.ModuleList()
|
148 |
+
for i in range(self.n_unet):
|
149 |
+
if i != self.n_unet-1:
|
150 |
+
self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
|
151 |
+
else:
|
152 |
+
self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
|
153 |
+
def forward(self, x):
|
154 |
+
y = x
|
155 |
+
for i in range(self.n_unet):
|
156 |
+
if i==0:
|
157 |
+
y = self.unet_list[i](y)
|
158 |
+
else:
|
159 |
+
y = self.unet_list[i](y+x)
|
160 |
+
y_mean, y_sigma = y[0], y[1]
|
161 |
+
return y_mean, y_sigma
|
162 |
+
|
163 |
+
class CasUNet_3head(nn.Module):
|
164 |
+
def __init__(self, n_unet, io_channels, bilinear=True):
|
165 |
+
super(CasUNet_3head, self).__init__()
|
166 |
+
self.n_unet = n_unet
|
167 |
+
self.io_channels = io_channels
|
168 |
+
self.bilinear = bilinear
|
169 |
+
####
|
170 |
+
self.unet_list = nn.ModuleList()
|
171 |
+
for i in range(self.n_unet):
|
172 |
+
if i != self.n_unet-1:
|
173 |
+
self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
|
174 |
+
else:
|
175 |
+
self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
|
176 |
+
def forward(self, x):
|
177 |
+
y = x
|
178 |
+
for i in range(self.n_unet):
|
179 |
+
if i==0:
|
180 |
+
y = self.unet_list[i](y)
|
181 |
+
else:
|
182 |
+
y = self.unet_list[i](y+x)
|
183 |
+
y_mean, y_alpha, y_beta = y[0], y[1], y[2]
|
184 |
+
return y_mean, y_alpha, y_beta
|
185 |
+
|
186 |
+
class UNet_2head(nn.Module):
|
187 |
+
def __init__(self, n_channels, out_channels, bilinear=True):
|
188 |
+
super(UNet_2head, self).__init__()
|
189 |
+
self.n_channels = n_channels
|
190 |
+
self.out_channels = out_channels
|
191 |
+
self.bilinear = bilinear
|
192 |
+
####
|
193 |
+
self.inc = ResConv(n_channels, 64)
|
194 |
+
self.down1 = Down(64, 128)
|
195 |
+
self.down2 = Down(128, 256)
|
196 |
+
self.down3 = Down(256, 512)
|
197 |
+
factor = 2 if bilinear else 1
|
198 |
+
self.down4 = Down(512, 1024 // factor)
|
199 |
+
self.up1 = Up(1024, 512 // factor, bilinear)
|
200 |
+
self.up2 = Up(512, 256 // factor, bilinear)
|
201 |
+
self.up3 = Up(256, 128 // factor, bilinear)
|
202 |
+
self.up4 = Up(128, 64, bilinear)
|
203 |
+
#per pixel multiple channels may exist
|
204 |
+
self.out_mean = OutConv(64, out_channels)
|
205 |
+
#variance will always be a single number for a pixel
|
206 |
+
self.out_var = nn.Sequential(
|
207 |
+
OutConv(64, 128),
|
208 |
+
OutConv(128, 1),
|
209 |
+
)
|
210 |
+
def forward(self, x):
|
211 |
+
x1 = self.inc(x)
|
212 |
+
x2 = self.down1(x1)
|
213 |
+
x3 = self.down2(x2)
|
214 |
+
x4 = self.down3(x3)
|
215 |
+
x5 = self.down4(x4)
|
216 |
+
x = self.up1(x5, x4)
|
217 |
+
x = self.up2(x, x3)
|
218 |
+
x = self.up3(x, x2)
|
219 |
+
x = self.up4(x, x1)
|
220 |
+
y_mean, y_var = self.out_mean(x), self.out_var(x)
|
221 |
+
return y_mean, y_var
|
222 |
+
|
223 |
+
class UNet_3head(nn.Module):
|
224 |
+
def __init__(self, n_channels, out_channels, bilinear=True):
|
225 |
+
super(UNet_3head, self).__init__()
|
226 |
+
self.n_channels = n_channels
|
227 |
+
self.out_channels = out_channels
|
228 |
+
self.bilinear = bilinear
|
229 |
+
####
|
230 |
+
self.inc = ResConv(n_channels, 64)
|
231 |
+
self.down1 = Down(64, 128)
|
232 |
+
self.down2 = Down(128, 256)
|
233 |
+
self.down3 = Down(256, 512)
|
234 |
+
factor = 2 if bilinear else 1
|
235 |
+
self.down4 = Down(512, 1024 // factor)
|
236 |
+
self.up1 = Up(1024, 512 // factor, bilinear)
|
237 |
+
self.up2 = Up(512, 256 // factor, bilinear)
|
238 |
+
self.up3 = Up(256, 128 // factor, bilinear)
|
239 |
+
self.up4 = Up(128, 64, bilinear)
|
240 |
+
#per pixel multiple channels may exist
|
241 |
+
self.out_mean = OutConv(64, out_channels)
|
242 |
+
#variance will always be a single number for a pixel
|
243 |
+
self.out_alpha = nn.Sequential(
|
244 |
+
OutConv(64, 128),
|
245 |
+
OutConv(128, 1),
|
246 |
+
nn.ReLU()
|
247 |
+
)
|
248 |
+
self.out_beta = nn.Sequential(
|
249 |
+
OutConv(64, 128),
|
250 |
+
OutConv(128, 1),
|
251 |
+
nn.ReLU()
|
252 |
+
)
|
253 |
+
def forward(self, x):
|
254 |
+
x1 = self.inc(x)
|
255 |
+
x2 = self.down1(x1)
|
256 |
+
x3 = self.down2(x2)
|
257 |
+
x4 = self.down3(x3)
|
258 |
+
x5 = self.down4(x4)
|
259 |
+
x = self.up1(x5, x4)
|
260 |
+
x = self.up2(x, x3)
|
261 |
+
x = self.up3(x, x2)
|
262 |
+
x = self.up4(x, x1)
|
263 |
+
y_mean, y_alpha, y_beta = self.out_mean(x), \
|
264 |
+
self.out_alpha(x), self.out_beta(x)
|
265 |
+
return y_mean, y_alpha, y_beta
|
266 |
+
|
267 |
+
class ResidualBlock(nn.Module):
|
268 |
+
def __init__(self, in_features):
|
269 |
+
super(ResidualBlock, self).__init__()
|
270 |
+
conv_block = [
|
271 |
+
nn.ReflectionPad2d(1),
|
272 |
+
nn.Conv2d(in_features, in_features, 3),
|
273 |
+
nn.InstanceNorm2d(in_features),
|
274 |
+
nn.ReLU(inplace=True),
|
275 |
+
nn.ReflectionPad2d(1),
|
276 |
+
nn.Conv2d(in_features, in_features, 3),
|
277 |
+
nn.InstanceNorm2d(in_features)
|
278 |
+
]
|
279 |
+
self.conv_block = nn.Sequential(*conv_block)
|
280 |
+
def forward(self, x):
|
281 |
+
return x + self.conv_block(x)
|
282 |
+
|
283 |
+
class Generator(nn.Module):
|
284 |
+
def __init__(self, input_nc, output_nc, n_residual_blocks=9):
|
285 |
+
super(Generator, self).__init__()
|
286 |
+
# Initial convolution block
|
287 |
+
model = [
|
288 |
+
nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
|
289 |
+
nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
|
290 |
+
]
|
291 |
+
# Downsampling
|
292 |
+
in_features = 64
|
293 |
+
out_features = in_features*2
|
294 |
+
for _ in range(2):
|
295 |
+
model += [
|
296 |
+
nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
|
297 |
+
nn.InstanceNorm2d(out_features),
|
298 |
+
nn.ReLU(inplace=True)
|
299 |
+
]
|
300 |
+
in_features = out_features
|
301 |
+
out_features = in_features*2
|
302 |
+
# Residual blocks
|
303 |
+
for _ in range(n_residual_blocks):
|
304 |
+
model += [ResidualBlock(in_features)]
|
305 |
+
# Upsampling
|
306 |
+
out_features = in_features//2
|
307 |
+
for _ in range(2):
|
308 |
+
model += [
|
309 |
+
nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
|
310 |
+
nn.InstanceNorm2d(out_features),
|
311 |
+
nn.ReLU(inplace=True)
|
312 |
+
]
|
313 |
+
in_features = out_features
|
314 |
+
out_features = in_features//2
|
315 |
+
# Output layer
|
316 |
+
model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
|
317 |
+
self.model = nn.Sequential(*model)
|
318 |
+
def forward(self, x):
|
319 |
+
return self.model(x)
|
320 |
+
|
321 |
+
|
322 |
+
class ResnetGenerator(nn.Module):
|
323 |
+
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
324 |
+
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
325 |
+
"""
|
326 |
+
|
327 |
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
|
328 |
+
"""Construct a Resnet-based generator
|
329 |
+
Parameters:
|
330 |
+
input_nc (int) -- the number of channels in input images
|
331 |
+
output_nc (int) -- the number of channels in output images
|
332 |
+
ngf (int) -- the number of filters in the last conv layer
|
333 |
+
norm_layer -- normalization layer
|
334 |
+
use_dropout (bool) -- if use dropout layers
|
335 |
+
n_blocks (int) -- the number of ResNet blocks
|
336 |
+
padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
|
337 |
+
"""
|
338 |
+
assert(n_blocks >= 0)
|
339 |
+
super(ResnetGenerator, self).__init__()
|
340 |
+
if type(norm_layer) == functools.partial:
|
341 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
342 |
+
else:
|
343 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
344 |
+
|
345 |
+
model = [nn.ReflectionPad2d(3),
|
346 |
+
nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
347 |
+
norm_layer(ngf),
|
348 |
+
nn.ReLU(True)]
|
349 |
+
|
350 |
+
n_downsampling = 2
|
351 |
+
for i in range(n_downsampling): # add downsampling layers
|
352 |
+
mult = 2 ** i
|
353 |
+
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
354 |
+
norm_layer(ngf * mult * 2),
|
355 |
+
nn.ReLU(True)]
|
356 |
+
|
357 |
+
mult = 2 ** n_downsampling
|
358 |
+
for i in range(n_blocks): # add ResNet blocks
|
359 |
+
|
360 |
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
|
361 |
+
|
362 |
+
for i in range(n_downsampling): # add upsampling layers
|
363 |
+
mult = 2 ** (n_downsampling - i)
|
364 |
+
model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
365 |
+
kernel_size=3, stride=2,
|
366 |
+
padding=1, output_padding=1,
|
367 |
+
bias=use_bias),
|
368 |
+
norm_layer(int(ngf * mult / 2)),
|
369 |
+
nn.ReLU(True)]
|
370 |
+
model += [nn.ReflectionPad2d(3)]
|
371 |
+
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
372 |
+
model += [nn.Tanh()]
|
373 |
+
|
374 |
+
self.model = nn.Sequential(*model)
|
375 |
+
|
376 |
+
def forward(self, input):
|
377 |
+
"""Standard forward"""
|
378 |
+
return self.model(input)
|
379 |
+
|
380 |
+
|
381 |
+
class ResnetBlock(nn.Module):
|
382 |
+
"""Define a Resnet block"""
|
383 |
+
|
384 |
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
385 |
+
"""Initialize the Resnet block
|
386 |
+
A resnet block is a conv block with skip connections
|
387 |
+
We construct a conv block with build_conv_block function,
|
388 |
+
and implement skip connections in <forward> function.
|
389 |
+
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
390 |
+
"""
|
391 |
+
super(ResnetBlock, self).__init__()
|
392 |
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
393 |
+
|
394 |
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
395 |
+
"""Construct a convolutional block.
|
396 |
+
Parameters:
|
397 |
+
dim (int) -- the number of channels in the conv layer.
|
398 |
+
padding_type (str) -- the name of padding layer: reflect | replicate | zero
|
399 |
+
norm_layer -- normalization layer
|
400 |
+
use_dropout (bool) -- if use dropout layers.
|
401 |
+
use_bias (bool) -- if the conv layer uses bias or not
|
402 |
+
Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
|
403 |
+
"""
|
404 |
+
conv_block = []
|
405 |
+
p = 0
|
406 |
+
if padding_type == 'reflect':
|
407 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
408 |
+
elif padding_type == 'replicate':
|
409 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
410 |
+
elif padding_type == 'zero':
|
411 |
+
p = 1
|
412 |
+
else:
|
413 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
414 |
+
|
415 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
|
416 |
+
if use_dropout:
|
417 |
+
conv_block += [nn.Dropout(0.5)]
|
418 |
+
|
419 |
+
p = 0
|
420 |
+
if padding_type == 'reflect':
|
421 |
+
conv_block += [nn.ReflectionPad2d(1)]
|
422 |
+
elif padding_type == 'replicate':
|
423 |
+
conv_block += [nn.ReplicationPad2d(1)]
|
424 |
+
elif padding_type == 'zero':
|
425 |
+
p = 1
|
426 |
+
else:
|
427 |
+
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
428 |
+
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
|
429 |
+
|
430 |
+
return nn.Sequential(*conv_block)
|
431 |
+
|
432 |
+
def forward(self, x):
|
433 |
+
"""Forward function (with skip connections)"""
|
434 |
+
out = x + self.conv_block(x) # add skip connections
|
435 |
+
return out
|
436 |
+
|
437 |
+
### discriminator
|
438 |
+
class NLayerDiscriminator(nn.Module):
|
439 |
+
"""Defines a PatchGAN discriminator"""
|
440 |
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
|
441 |
+
"""Construct a PatchGAN discriminator
|
442 |
+
Parameters:
|
443 |
+
input_nc (int) -- the number of channels in input images
|
444 |
+
ndf (int) -- the number of filters in the last conv layer
|
445 |
+
n_layers (int) -- the number of conv layers in the discriminator
|
446 |
+
norm_layer -- normalization layer
|
447 |
+
"""
|
448 |
+
super(NLayerDiscriminator, self).__init__()
|
449 |
+
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
450 |
+
use_bias = norm_layer.func == nn.InstanceNorm2d
|
451 |
+
else:
|
452 |
+
use_bias = norm_layer == nn.InstanceNorm2d
|
453 |
+
kw = 4
|
454 |
+
padw = 1
|
455 |
+
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
456 |
+
nf_mult = 1
|
457 |
+
nf_mult_prev = 1
|
458 |
+
for n in range(1, n_layers): # gradually increase the number of filters
|
459 |
+
nf_mult_prev = nf_mult
|
460 |
+
nf_mult = min(2 ** n, 8)
|
461 |
+
sequence += [
|
462 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
463 |
+
norm_layer(ndf * nf_mult),
|
464 |
+
nn.LeakyReLU(0.2, True)
|
465 |
+
]
|
466 |
+
nf_mult_prev = nf_mult
|
467 |
+
nf_mult = min(2 ** n_layers, 8)
|
468 |
+
sequence += [
|
469 |
+
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
470 |
+
norm_layer(ndf * nf_mult),
|
471 |
+
nn.LeakyReLU(0.2, True)
|
472 |
+
]
|
473 |
+
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
474 |
+
self.model = nn.Sequential(*sequence)
|
475 |
+
def forward(self, input):
|
476 |
+
"""Standard forward."""
|
477 |
+
return self.model(input)
|
requirements.txt
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may be used to create an environment using:
|
2 |
+
# $ conda create --name <env> --file <this file>
|
3 |
+
# platform: linux-64
|
4 |
+
_libgcc_mutex=0.1=conda_forge
|
5 |
+
_openmp_mutex=4.5=2_kmp_llvm
|
6 |
+
aiohttp=3.8.1=pypi_0
|
7 |
+
aiosignal=1.2.0=pypi_0
|
8 |
+
albumentations=1.2.0=pyhd8ed1ab_0
|
9 |
+
alsa-lib=1.2.6.1=h7f98852_0
|
10 |
+
analytics-python=1.4.0=pypi_0
|
11 |
+
anyio=3.6.1=pypi_0
|
12 |
+
aom=3.3.0=h27087fc_1
|
13 |
+
argon2-cffi=21.3.0=pypi_0
|
14 |
+
argon2-cffi-bindings=21.2.0=pypi_0
|
15 |
+
asttokens=2.0.5=pypi_0
|
16 |
+
async-timeout=4.0.2=pypi_0
|
17 |
+
attr=2.5.1=h166bdaf_0
|
18 |
+
attrs=21.4.0=pypi_0
|
19 |
+
babel=2.10.1=pypi_0
|
20 |
+
backcall=0.2.0=pypi_0
|
21 |
+
backoff=1.10.0=pypi_0
|
22 |
+
bcrypt=3.2.2=pypi_0
|
23 |
+
beautifulsoup4=4.11.1=pypi_0
|
24 |
+
blas=1.0=mkl
|
25 |
+
bleach=5.0.0=pypi_0
|
26 |
+
blosc=1.21.1=h83bc5f7_3
|
27 |
+
brotli=1.0.9=h166bdaf_7
|
28 |
+
brotli-bin=1.0.9=h166bdaf_7
|
29 |
+
brotlipy=0.7.0=py310h7f8727e_1002
|
30 |
+
brunsli=0.1=h9c3ff4c_0
|
31 |
+
bzip2=1.0.8=h7b6447c_0
|
32 |
+
c-ares=1.18.1=h7f98852_0
|
33 |
+
c-blosc2=2.2.0=h7a311fb_0
|
34 |
+
ca-certificates=2022.6.15=ha878542_0
|
35 |
+
cairo=1.16.0=ha61ee94_1011
|
36 |
+
certifi=2022.6.15=py310hff52083_0
|
37 |
+
cffi=1.15.0=py310hd667e15_1
|
38 |
+
cfitsio=4.1.0=hd9d235c_0
|
39 |
+
charls=2.3.4=h9c3ff4c_0
|
40 |
+
charset-normalizer=2.0.4=pyhd3eb1b0_0
|
41 |
+
click=8.1.3=pypi_0
|
42 |
+
cloudpickle=2.1.0=pyhd8ed1ab_0
|
43 |
+
cryptography=37.0.1=py310h9ce1e76_0
|
44 |
+
cudatoolkit=10.2.89=hfd86e86_1
|
45 |
+
cycler=0.11.0=pypi_0
|
46 |
+
cytoolz=0.11.2=py310h5764c6d_2
|
47 |
+
dask-core=2022.7.0=pyhd8ed1ab_0
|
48 |
+
dbus=1.13.6=h5008d03_3
|
49 |
+
debugpy=1.6.0=pypi_0
|
50 |
+
decorator=5.1.1=pypi_0
|
51 |
+
defusedxml=0.7.1=pypi_0
|
52 |
+
entrypoints=0.4=pypi_0
|
53 |
+
executing=0.8.3=pypi_0
|
54 |
+
expat=2.4.8=h27087fc_0
|
55 |
+
fastapi=0.78.0=pypi_0
|
56 |
+
fastjsonschema=2.15.3=pypi_0
|
57 |
+
ffmpeg=4.4.2=habc3f16_0
|
58 |
+
ffmpy=0.3.0=pypi_0
|
59 |
+
fftw=3.3.10=nompi_h77c792f_102
|
60 |
+
fire=0.4.0=pypi_0
|
61 |
+
font-ttf-dejavu-sans-mono=2.37=hab24e00_0
|
62 |
+
font-ttf-inconsolata=3.000=h77eed37_0
|
63 |
+
font-ttf-source-code-pro=2.038=h77eed37_0
|
64 |
+
font-ttf-ubuntu=0.83=hab24e00_0
|
65 |
+
fontconfig=2.14.0=h8e229c2_0
|
66 |
+
fonts-conda-ecosystem=1=0
|
67 |
+
fonts-conda-forge=1=0
|
68 |
+
fonttools=4.33.3=pypi_0
|
69 |
+
freeglut=3.2.2=h9c3ff4c_1
|
70 |
+
freetype=2.11.0=h70c0345_0
|
71 |
+
frozenlist=1.3.0=pypi_0
|
72 |
+
fsspec=2022.5.0=pyhd8ed1ab_0
|
73 |
+
ftfy=6.1.1=pypi_0
|
74 |
+
gettext=0.19.8.1=h73d1719_1008
|
75 |
+
giflib=5.2.1=h7b6447c_0
|
76 |
+
glib=2.70.2=h780b84a_4
|
77 |
+
glib-tools=2.70.2=h780b84a_4
|
78 |
+
gmp=6.2.1=h295c915_3
|
79 |
+
gnutls=3.7.6=hb5d6004_1
|
80 |
+
gradio=3.0.24=pypi_0
|
81 |
+
graphite2=1.3.13=h58526e2_1001
|
82 |
+
gst-plugins-base=1.20.3=hf6a322e_0
|
83 |
+
gstreamer=1.20.3=hd4edc92_0
|
84 |
+
h11=0.12.0=pypi_0
|
85 |
+
harfbuzz=4.4.1=hf9f4e7c_0
|
86 |
+
hdf5=1.12.1=nompi_h2386368_104
|
87 |
+
httpcore=0.15.0=pypi_0
|
88 |
+
httpx=0.23.0=pypi_0
|
89 |
+
icu=70.1=h27087fc_0
|
90 |
+
idna=3.3=pyhd3eb1b0_0
|
91 |
+
imagecodecs=2022.2.22=py310h3ac3b6e_6
|
92 |
+
imageio=2.19.3=pyhcf75d05_0
|
93 |
+
intel-openmp=2021.4.0=h06a4308_3561
|
94 |
+
ipykernel=6.13.0=pypi_0
|
95 |
+
ipython=8.4.0=pypi_0
|
96 |
+
ipython-genutils=0.2.0=pypi_0
|
97 |
+
jack=1.9.18=h8c3723f_1002
|
98 |
+
jasper=2.0.33=ha77e612_0
|
99 |
+
jedi=0.18.1=pypi_0
|
100 |
+
jinja2=3.1.2=pypi_0
|
101 |
+
joblib=1.1.0=pyhd8ed1ab_0
|
102 |
+
jpeg=9e=h7f8727e_0
|
103 |
+
json5=0.9.8=pypi_0
|
104 |
+
jsonschema=4.6.0=pypi_0
|
105 |
+
jupyter-client=7.3.1=pypi_0
|
106 |
+
jupyter-core=4.10.0=pypi_0
|
107 |
+
jupyter-server=1.17.0=pypi_0
|
108 |
+
jupyterlab=3.4.2=pypi_0
|
109 |
+
jupyterlab-pygments=0.2.2=pypi_0
|
110 |
+
jupyterlab-server=2.14.0=pypi_0
|
111 |
+
jxrlib=1.1=h7f98852_2
|
112 |
+
keyutils=1.6.1=h166bdaf_0
|
113 |
+
kiwisolver=1.4.2=pypi_0
|
114 |
+
kornia=0.6.5=pypi_0
|
115 |
+
krb5=1.19.3=h3790be6_0
|
116 |
+
lame=3.100=h7b6447c_0
|
117 |
+
lcms2=2.12=h3be6417_0
|
118 |
+
ld_impl_linux-64=2.38=h1181459_1
|
119 |
+
lerc=3.0=h9c3ff4c_0
|
120 |
+
libaec=1.0.6=h9c3ff4c_0
|
121 |
+
libavif=0.10.1=h166bdaf_0
|
122 |
+
libblas=3.9.0=12_linux64_mkl
|
123 |
+
libbrotlicommon=1.0.9=h166bdaf_7
|
124 |
+
libbrotlidec=1.0.9=h166bdaf_7
|
125 |
+
libbrotlienc=1.0.9=h166bdaf_7
|
126 |
+
libcap=2.64=ha37c62d_0
|
127 |
+
libcblas=3.9.0=12_linux64_mkl
|
128 |
+
libclang=14.0.6=default_h2e3cab8_0
|
129 |
+
libclang13=14.0.6=default_h3a83d3e_0
|
130 |
+
libcups=2.3.3=hf5a7f15_1
|
131 |
+
libcurl=7.83.1=h7bff187_0
|
132 |
+
libdb=6.2.32=h9c3ff4c_0
|
133 |
+
libdeflate=1.12=h166bdaf_0
|
134 |
+
libdrm=2.4.112=h166bdaf_0
|
135 |
+
libedit=3.1.20191231=he28a2e2_2
|
136 |
+
libev=4.33=h516909a_1
|
137 |
+
libevent=2.1.10=h9b69904_4
|
138 |
+
libffi=3.4.2=h7f98852_5
|
139 |
+
libflac=1.3.4=h27087fc_0
|
140 |
+
libgcc-ng=12.1.0=h8d9b700_16
|
141 |
+
libgfortran-ng=12.1.0=h69a702a_16
|
142 |
+
libgfortran5=12.1.0=hdcd56e2_16
|
143 |
+
libglib=2.70.2=h174f98d_4
|
144 |
+
libglu=9.0.0=he1b5a44_1001
|
145 |
+
libiconv=1.16=h7f8727e_2
|
146 |
+
libidn2=2.3.2=h7f8727e_0
|
147 |
+
liblapack=3.9.0=12_linux64_mkl
|
148 |
+
liblapacke=3.9.0=12_linux64_mkl
|
149 |
+
libllvm14=14.0.6=he0ac6c6_0
|
150 |
+
libnghttp2=1.47.0=h727a467_0
|
151 |
+
libnsl=2.0.0=h7f98852_0
|
152 |
+
libogg=1.3.4=h7f98852_1
|
153 |
+
libopencv=4.5.5=py310hcb97b83_13
|
154 |
+
libopus=1.3.1=h7f98852_1
|
155 |
+
libpciaccess=0.16=h516909a_0
|
156 |
+
libpng=1.6.37=hbc83047_0
|
157 |
+
libpq=14.4=hd77ab85_0
|
158 |
+
libprotobuf=3.20.1=h6239696_0
|
159 |
+
libsndfile=1.0.31=h9c3ff4c_1
|
160 |
+
libssh2=1.10.0=ha56f1ee_2
|
161 |
+
libstdcxx-ng=12.1.0=ha89aaad_16
|
162 |
+
libtasn1=4.16.0=h27cfd23_0
|
163 |
+
libtiff=4.4.0=hc85c160_1
|
164 |
+
libtool=2.4.6=h9c3ff4c_1008
|
165 |
+
libudev1=249=h166bdaf_4
|
166 |
+
libunistring=0.9.10=h27cfd23_0
|
167 |
+
libuuid=2.32.1=h7f98852_1000
|
168 |
+
libuv=1.40.0=h7b6447c_0
|
169 |
+
libva=2.15.0=h166bdaf_0
|
170 |
+
libvorbis=1.3.7=h9c3ff4c_0
|
171 |
+
libvpx=1.11.0=h9c3ff4c_3
|
172 |
+
libwebp=1.2.2=h55f646e_0
|
173 |
+
libwebp-base=1.2.2=h7f8727e_0
|
174 |
+
libxcb=1.13=h7f98852_1004
|
175 |
+
libxkbcommon=1.0.3=he3ba5ed_0
|
176 |
+
libxml2=2.9.14=h22db469_3
|
177 |
+
libzlib=1.2.12=h166bdaf_1
|
178 |
+
libzopfli=1.0.3=h9c3ff4c_0
|
179 |
+
linkify-it-py=1.0.3=pypi_0
|
180 |
+
llvm-openmp=14.0.4=he0ac6c6_0
|
181 |
+
locket=1.0.0=pyhd8ed1ab_0
|
182 |
+
lz4-c=1.9.3=h295c915_1
|
183 |
+
markdown-it-py=2.1.0=pypi_0
|
184 |
+
markupsafe=2.1.1=pypi_0
|
185 |
+
matplotlib=3.5.2=pypi_0
|
186 |
+
matplotlib-inline=0.1.3=pypi_0
|
187 |
+
mdit-py-plugins=0.3.0=pypi_0
|
188 |
+
mdurl=0.1.1=pypi_0
|
189 |
+
mistune=0.8.4=pypi_0
|
190 |
+
mkl=2021.4.0=h06a4308_640
|
191 |
+
mkl-service=2.4.0=py310h7f8727e_0
|
192 |
+
mkl_fft=1.3.1=py310hd6ae3a3_0
|
193 |
+
mkl_random=1.2.2=py310h00e6091_0
|
194 |
+
mltk=0.0.5=pypi_0
|
195 |
+
monotonic=1.6=pypi_0
|
196 |
+
multidict=6.0.2=pypi_0
|
197 |
+
munch=2.5.0=pypi_0
|
198 |
+
mysql-common=8.0.29=haf5c9bc_1
|
199 |
+
mysql-libs=8.0.29=h28c427c_1
|
200 |
+
nbclassic=0.3.7=pypi_0
|
201 |
+
nbclient=0.6.4=pypi_0
|
202 |
+
nbconvert=6.5.0=pypi_0
|
203 |
+
nbformat=5.4.0=pypi_0
|
204 |
+
ncurses=6.3=h7f8727e_2
|
205 |
+
nest-asyncio=1.5.5=pypi_0
|
206 |
+
nettle=3.7.3=hbbd107a_1
|
207 |
+
networkx=2.8.4=pyhd8ed1ab_0
|
208 |
+
nltk=3.7=pypi_0
|
209 |
+
notebook=6.4.11=pypi_0
|
210 |
+
notebook-shim=0.1.0=pypi_0
|
211 |
+
nspr=4.32=h9c3ff4c_1
|
212 |
+
nss=3.78=h2350873_0
|
213 |
+
ntk=1.1.3=pypi_0
|
214 |
+
numpy=1.22.3=py310hfa59a62_0
|
215 |
+
numpy-base=1.22.3=py310h9585f30_0
|
216 |
+
opencv=4.5.5=py310hff52083_13
|
217 |
+
opencv-python=4.6.0.66=pypi_0
|
218 |
+
openh264=2.1.1=h4ff587b_0
|
219 |
+
openjpeg=2.4.0=hb52868f_1
|
220 |
+
openssl=1.1.1q=h166bdaf_0
|
221 |
+
orjson=3.7.7=pypi_0
|
222 |
+
packaging=21.3=pyhd8ed1ab_0
|
223 |
+
pandas=1.4.2=pypi_0
|
224 |
+
pandocfilters=1.5.0=pypi_0
|
225 |
+
paramiko=2.11.0=pypi_0
|
226 |
+
parso=0.8.3=pypi_0
|
227 |
+
partd=1.2.0=pyhd8ed1ab_0
|
228 |
+
pcre=8.45=h9c3ff4c_0
|
229 |
+
pexpect=4.8.0=pypi_0
|
230 |
+
pickleshare=0.7.5=pypi_0
|
231 |
+
pillow=9.0.1=py310h22f2fdc_0
|
232 |
+
pip=21.2.4=py310h06a4308_0
|
233 |
+
pixman=0.40.0=h36c2ea0_0
|
234 |
+
portaudio=19.6.0=h57a0ea0_5
|
235 |
+
prometheus-client=0.14.1=pypi_0
|
236 |
+
prompt-toolkit=3.0.29=pypi_0
|
237 |
+
psutil=5.9.1=pypi_0
|
238 |
+
pthread-stubs=0.4=h36c2ea0_1001
|
239 |
+
ptyprocess=0.7.0=pypi_0
|
240 |
+
pulseaudio=14.0=h7f54b18_8
|
241 |
+
pure-eval=0.2.2=pypi_0
|
242 |
+
py-opencv=4.5.5=py310hfdc917e_13
|
243 |
+
pycocotools=2.0.4=pypi_0
|
244 |
+
pycparser=2.21=pyhd3eb1b0_0
|
245 |
+
pycryptodome=3.15.0=pypi_0
|
246 |
+
pydantic=1.9.1=pypi_0
|
247 |
+
pydub=0.25.1=pypi_0
|
248 |
+
pygments=2.12.0=pypi_0
|
249 |
+
pynacl=1.5.0=pypi_0
|
250 |
+
pyopenssl=22.0.0=pyhd3eb1b0_0
|
251 |
+
pyparsing=3.0.9=pyhd8ed1ab_0
|
252 |
+
pyrsistent=0.18.1=pypi_0
|
253 |
+
pysocks=1.7.1=py310h06a4308_0
|
254 |
+
python=3.10.5=h582c2e5_0_cpython
|
255 |
+
python-dateutil=2.8.2=pypi_0
|
256 |
+
python-multipart=0.0.5=pypi_0
|
257 |
+
python_abi=3.10=2_cp310
|
258 |
+
pytorch=1.11.0=py3.10_cuda10.2_cudnn7.6.5_0
|
259 |
+
pytorch-mutex=1.0=cuda
|
260 |
+
pytz=2022.1=pypi_0
|
261 |
+
pywavelets=1.3.0=py310hde88566_1
|
262 |
+
pyyaml=6.0=py310h5764c6d_4
|
263 |
+
pyzmq=23.1.0=pypi_0
|
264 |
+
qt-main=5.15.4=ha5833f6_2
|
265 |
+
qudida=0.0.4=pyhd8ed1ab_0
|
266 |
+
readline=8.1.2=h7f8727e_1
|
267 |
+
regex=2022.6.2=pypi_0
|
268 |
+
requests=2.27.1=pyhd3eb1b0_0
|
269 |
+
rfc3986=1.5.0=pypi_0
|
270 |
+
scikit-image=0.19.3=py310h769672d_0
|
271 |
+
scikit-learn=1.1.1=py310hffb9edd_0
|
272 |
+
scipy=1.8.1=py310h7612f91_0
|
273 |
+
seaborn=0.11.2=pypi_0
|
274 |
+
send2trash=1.8.0=pypi_0
|
275 |
+
setuptools=61.2.0=py310h06a4308_0
|
276 |
+
six=1.16.0=pyhd3eb1b0_1
|
277 |
+
snappy=1.1.9=hbd366e4_1
|
278 |
+
sniffio=1.2.0=pypi_0
|
279 |
+
soupsieve=2.3.2.post1=pypi_0
|
280 |
+
sqlite=3.39.0=h4ff8645_0
|
281 |
+
stack-data=0.2.0=pypi_0
|
282 |
+
starlette=0.19.1=pypi_0
|
283 |
+
svt-av1=1.1.0=h27087fc_1
|
284 |
+
termcolor=1.1.0=pypi_0
|
285 |
+
terminado=0.15.0=pypi_0
|
286 |
+
threadpoolctl=3.1.0=pyh8a188c0_0
|
287 |
+
tifffile=2022.5.4=pyhd8ed1ab_0
|
288 |
+
tinycss2=1.1.1=pypi_0
|
289 |
+
tk=8.6.12=h1ccaba5_0
|
290 |
+
toolz=0.11.2=pyhd8ed1ab_0
|
291 |
+
torchaudio=0.11.0=py310_cu102
|
292 |
+
torchvision=0.12.0=py310_cu102
|
293 |
+
tornado=6.1=pypi_0
|
294 |
+
tqdm=4.64.0=pypi_0
|
295 |
+
traitlets=5.2.2.post1=pypi_0
|
296 |
+
typing-extensions=4.1.1=hd3eb1b0_0
|
297 |
+
typing_extensions=4.1.1=pyh06a4308_0
|
298 |
+
tzdata=2022a=hda174b7_0
|
299 |
+
uc-micro-py=1.0.1=pypi_0
|
300 |
+
urllib3=1.26.9=py310h06a4308_0
|
301 |
+
uvicorn=0.18.2=pypi_0
|
302 |
+
wcwidth=0.2.5=pypi_0
|
303 |
+
webencodings=0.5.1=pypi_0
|
304 |
+
websocket-client=1.3.2=pypi_0
|
305 |
+
wheel=0.37.1=pyhd3eb1b0_0
|
306 |
+
x264=1!161.3030=h7f98852_1
|
307 |
+
x265=3.5=h924138e_3
|
308 |
+
xcb-util=0.4.0=h166bdaf_0
|
309 |
+
xcb-util-image=0.4.0=h166bdaf_0
|
310 |
+
xcb-util-keysyms=0.4.0=h166bdaf_0
|
311 |
+
xcb-util-renderutil=0.3.9=h166bdaf_0
|
312 |
+
xcb-util-wm=0.4.1=h166bdaf_0
|
313 |
+
xorg-fixesproto=5.0=h7f98852_1002
|
314 |
+
xorg-inputproto=2.3.2=h7f98852_1002
|
315 |
+
xorg-kbproto=1.0.7=h7f98852_1002
|
316 |
+
xorg-libice=1.0.10=h7f98852_0
|
317 |
+
xorg-libsm=1.2.3=hd9c2040_1000
|
318 |
+
xorg-libx11=1.7.2=h7f98852_0
|
319 |
+
xorg-libxau=1.0.9=h7f98852_0
|
320 |
+
xorg-libxdmcp=1.1.3=h7f98852_0
|
321 |
+
xorg-libxext=1.3.4=h7f98852_1
|
322 |
+
xorg-libxfixes=5.0.3=h7f98852_1004
|
323 |
+
xorg-libxi=1.7.10=h7f98852_0
|
324 |
+
xorg-libxrender=0.9.10=h7f98852_1003
|
325 |
+
xorg-renderproto=0.11.1=h7f98852_1002
|
326 |
+
xorg-xextproto=7.3.0=h7f98852_1002
|
327 |
+
xorg-xproto=7.0.31=h7f98852_1007
|
328 |
+
xz=5.2.5=h7f8727e_1
|
329 |
+
yaml=0.2.5=h7f98852_2
|
330 |
+
yarl=1.7.2=pypi_0
|
331 |
+
zfp=0.5.5=h9c3ff4c_8
|
332 |
+
zlib=1.2.12=h166bdaf_1
|
333 |
+
zlib-ng=2.0.6=h166bdaf_0
|
334 |
+
zstd=1.5.2=ha4553b6_0
|
src/.gitkeep
ADDED
File without changes
|
src/__pycache__/ds.cpython-310.pyc
ADDED
Binary file (14.6 kB). View file
|
|
src/__pycache__/losses.cpython-310.pyc
ADDED
Binary file (4.17 kB). View file
|
|
src/__pycache__/networks_SRGAN.cpython-310.pyc
ADDED
Binary file (6.99 kB). View file
|
|
src/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (34 kB). View file
|
|
src/app.py
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from matplotlib import cm
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision.models as models
|
10 |
+
from torch.utils.data import Dataset, DataLoader
|
11 |
+
from torchvision import transforms
|
12 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
13 |
+
|
14 |
+
from PIL import Image
|
15 |
+
|
16 |
+
from ds import *
|
17 |
+
from losses import *
|
18 |
+
from networks_SRGAN import *
|
19 |
+
from utils import *
|
20 |
+
|
21 |
+
|
22 |
+
NetG = Generator()
|
23 |
+
model_parameters = filter(lambda p: True, NetG.parameters())
|
24 |
+
params = sum([np.prod(p.size()) for p in model_parameters])
|
25 |
+
print("Number of Parameters:",params)
|
26 |
+
NetC = BayesCap(in_channels=3, out_channels=3)
|
27 |
+
|
28 |
+
|
29 |
+
NetG = Generator()
|
30 |
+
NetG.load_state_dict(torch.load('../ckpt/srgan-ImageNet-bc347d67.pth', map_location='cuda:0'))
|
31 |
+
NetG.to('cuda')
|
32 |
+
NetG.eval()
|
33 |
+
|
34 |
+
NetC = BayesCap(in_channels=3, out_channels=3)
|
35 |
+
NetC.load_state_dict(torch.load('../ckpt/BayesCap_SRGAN_best.pth', map_location='cuda:0'))
|
36 |
+
NetC.to('cuda')
|
37 |
+
NetC.eval()
|
38 |
+
|
39 |
+
def tensor01_to_pil(xt):
|
40 |
+
r = transforms.ToPILImage(mode='RGB')(xt.squeeze())
|
41 |
+
return r
|
42 |
+
|
43 |
+
|
44 |
+
def predict(img):
|
45 |
+
"""
|
46 |
+
img: image
|
47 |
+
"""
|
48 |
+
image_size = (256,256)
|
49 |
+
upscale_factor = 4
|
50 |
+
lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
|
51 |
+
# lr_transforms = transforms.Resize((128, 128), interpolation=IMode.BICUBIC, antialias=True)
|
52 |
+
|
53 |
+
img = Image.fromarray(np.array(img))
|
54 |
+
img = lr_transforms(img)
|
55 |
+
lr_tensor = utils.image2tensor(img, range_norm=False, half=False)
|
56 |
+
|
57 |
+
device = 'cuda'
|
58 |
+
dtype = torch.cuda.FloatTensor
|
59 |
+
xLR = lr_tensor.to(device).unsqueeze(0)
|
60 |
+
xLR = xLR.type(dtype)
|
61 |
+
# pass them through the network
|
62 |
+
with torch.no_grad():
|
63 |
+
xSR = NetG(xLR)
|
64 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
65 |
+
|
66 |
+
a_map = (1/(xSRC_alpha[0] + 1e-5)).to('cpu').data
|
67 |
+
b_map = xSRC_beta[0].to('cpu').data
|
68 |
+
u_map = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
69 |
+
|
70 |
+
|
71 |
+
x_LR = tensor01_to_pil(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
72 |
+
|
73 |
+
x_mean = tensor01_to_pil(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
74 |
+
|
75 |
+
#im = Image.fromarray(np.uint8(cm.gist_earth(myarray)*255))
|
76 |
+
|
77 |
+
a_map = torch.clamp(a_map, min=0, max=0.1)
|
78 |
+
a_map = (a_map - a_map.min())/(a_map.max() - a_map.min())
|
79 |
+
x_alpha = Image.fromarray(np.uint8(cm.inferno(a_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
80 |
+
|
81 |
+
b_map = torch.clamp(b_map, min=0.45, max=0.75)
|
82 |
+
b_map = (b_map - b_map.min())/(b_map.max() - b_map.min())
|
83 |
+
x_beta = Image.fromarray(np.uint8(cm.cividis(b_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
84 |
+
|
85 |
+
u_map = torch.clamp(u_map, min=0, max=0.15)
|
86 |
+
u_map = (u_map - u_map.min())/(u_map.max() - u_map.min())
|
87 |
+
x_uncer = Image.fromarray(np.uint8(cm.hot(u_map.transpose(0,2).transpose(0,1).squeeze())*255))
|
88 |
+
|
89 |
+
return x_LR, x_mean, x_alpha, x_beta, x_uncer
|
90 |
+
|
91 |
+
import gradio as gr
|
92 |
+
|
93 |
+
title = "BayesCap"
|
94 |
+
description = "BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks (ECCV 2022)"
|
95 |
+
article = "<p style='text-align: center'> BayesCap: Bayesian Identity Cap for Calibrated Uncertainty in Frozen Neural Networks| <a href='https://github.com/ExplainableML/BayesCap'>Github Repo</a></p>"
|
96 |
+
|
97 |
+
|
98 |
+
gr.Interface(
|
99 |
+
fn=predict,
|
100 |
+
inputs=gr.inputs.Image(type='pil', label="Orignal"),
|
101 |
+
outputs=[
|
102 |
+
gr.outputs.Image(type='pil', label="Low-res"),
|
103 |
+
gr.outputs.Image(type='pil', label="Super-res"),
|
104 |
+
gr.outputs.Image(type='pil', label="Alpha"),
|
105 |
+
gr.outputs.Image(type='pil', label="Beta"),
|
106 |
+
gr.outputs.Image(type='pil', label="Uncertainty")
|
107 |
+
],
|
108 |
+
title=title,
|
109 |
+
description=description,
|
110 |
+
article=article,
|
111 |
+
examples=[
|
112 |
+
["../demo_examples/baby.png"],
|
113 |
+
["../demo_examples/bird.png"]
|
114 |
+
]
|
115 |
+
).launch(share=True)
|
src/ds.py
ADDED
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import absolute_import, division, print_function
|
2 |
+
|
3 |
+
import random
|
4 |
+
import copy
|
5 |
+
import io
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
from PIL import Image
|
9 |
+
import skimage.transform
|
10 |
+
from collections import Counter
|
11 |
+
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.utils.data as data
|
15 |
+
from torch import Tensor
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
from torchvision import transforms
|
18 |
+
from torchvision.transforms.functional import InterpolationMode as IMode
|
19 |
+
|
20 |
+
import utils
|
21 |
+
|
22 |
+
class ImgDset(Dataset):
|
23 |
+
"""Customize the data set loading function and prepare low/high resolution image data in advance.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
dataroot (str): Training data set address
|
27 |
+
image_size (int): High resolution image size
|
28 |
+
upscale_factor (int): Image magnification
|
29 |
+
mode (str): Data set loading method, the training data set is for data enhancement,
|
30 |
+
and the verification data set is not for data enhancement
|
31 |
+
|
32 |
+
"""
|
33 |
+
|
34 |
+
def __init__(self, dataroot: str, image_size: int, upscale_factor: int, mode: str) -> None:
|
35 |
+
super(ImgDset, self).__init__()
|
36 |
+
self.filenames = [os.path.join(dataroot, x) for x in os.listdir(dataroot)]
|
37 |
+
|
38 |
+
if mode == "train":
|
39 |
+
self.hr_transforms = transforms.Compose([
|
40 |
+
transforms.RandomCrop(image_size),
|
41 |
+
transforms.RandomRotation(90),
|
42 |
+
transforms.RandomHorizontalFlip(0.5),
|
43 |
+
])
|
44 |
+
else:
|
45 |
+
self.hr_transforms = transforms.Resize(image_size)
|
46 |
+
|
47 |
+
self.lr_transforms = transforms.Resize((image_size[0]//upscale_factor, image_size[1]//upscale_factor), interpolation=IMode.BICUBIC, antialias=True)
|
48 |
+
|
49 |
+
def __getitem__(self, batch_index: int) -> [Tensor, Tensor]:
|
50 |
+
# Read a batch of image data
|
51 |
+
image = Image.open(self.filenames[batch_index])
|
52 |
+
|
53 |
+
# Transform image
|
54 |
+
hr_image = self.hr_transforms(image)
|
55 |
+
lr_image = self.lr_transforms(hr_image)
|
56 |
+
|
57 |
+
# Convert image data into Tensor stream format (PyTorch).
|
58 |
+
# Note: The range of input and output is between [0, 1]
|
59 |
+
lr_tensor = utils.image2tensor(lr_image, range_norm=False, half=False)
|
60 |
+
hr_tensor = utils.image2tensor(hr_image, range_norm=False, half=False)
|
61 |
+
|
62 |
+
return lr_tensor, hr_tensor
|
63 |
+
|
64 |
+
def __len__(self) -> int:
|
65 |
+
return len(self.filenames)
|
66 |
+
|
67 |
+
|
68 |
+
class PairedImages_w_nameList(Dataset):
|
69 |
+
'''
|
70 |
+
can act as supervised or un-supervised based on flists
|
71 |
+
'''
|
72 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
73 |
+
self.flist1 = flist1
|
74 |
+
self.flist2 = flist2
|
75 |
+
self.transform1 = transform1
|
76 |
+
self.transform2 = transform2
|
77 |
+
self.do_aug = do_aug
|
78 |
+
def __getitem__(self, index):
|
79 |
+
impath1 = self.flist1[index]
|
80 |
+
img1 = Image.open(impath1).convert('RGB')
|
81 |
+
impath2 = self.flist2[index]
|
82 |
+
img2 = Image.open(impath2).convert('RGB')
|
83 |
+
|
84 |
+
img1 = utils.image2tensor(img1, range_norm=False, half=False)
|
85 |
+
img2 = utils.image2tensor(img2, range_norm=False, half=False)
|
86 |
+
|
87 |
+
if self.transform1 is not None:
|
88 |
+
img1 = self.transform1(img1)
|
89 |
+
if self.transform2 is not None:
|
90 |
+
img2 = self.transform2(img2)
|
91 |
+
|
92 |
+
return img1, img2
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.flist1)
|
95 |
+
|
96 |
+
class PairedImages_w_nameList_npy(Dataset):
|
97 |
+
'''
|
98 |
+
can act as supervised or un-supervised based on flists
|
99 |
+
'''
|
100 |
+
def __init__(self, flist1, flist2, transform1=None, transform2=None, do_aug=False):
|
101 |
+
self.flist1 = flist1
|
102 |
+
self.flist2 = flist2
|
103 |
+
self.transform1 = transform1
|
104 |
+
self.transform2 = transform2
|
105 |
+
self.do_aug = do_aug
|
106 |
+
def __getitem__(self, index):
|
107 |
+
impath1 = self.flist1[index]
|
108 |
+
img1 = np.load(impath1)
|
109 |
+
impath2 = self.flist2[index]
|
110 |
+
img2 = np.load(impath2)
|
111 |
+
|
112 |
+
if self.transform1 is not None:
|
113 |
+
img1 = self.transform1(img1)
|
114 |
+
if self.transform2 is not None:
|
115 |
+
img2 = self.transform2(img2)
|
116 |
+
|
117 |
+
return img1, img2
|
118 |
+
def __len__(self):
|
119 |
+
return len(self.flist1)
|
120 |
+
|
121 |
+
# def call_paired():
|
122 |
+
# root1='./GOPRO_3840FPS_AVG_3-21/train/blur/'
|
123 |
+
# root2='./GOPRO_3840FPS_AVG_3-21/train/sharp/'
|
124 |
+
|
125 |
+
# flist1=glob.glob(root1+'/*/*.png')
|
126 |
+
# flist2=glob.glob(root2+'/*/*.png')
|
127 |
+
|
128 |
+
# dset = PairedImages_w_nameList(root1,root2,flist1,flist2)
|
129 |
+
|
130 |
+
#### KITTI depth
|
131 |
+
|
132 |
+
def load_velodyne_points(filename):
|
133 |
+
"""Load 3D point cloud from KITTI file format
|
134 |
+
(adapted from https://github.com/hunse/kitti)
|
135 |
+
"""
|
136 |
+
points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4)
|
137 |
+
points[:, 3] = 1.0 # homogeneous
|
138 |
+
return points
|
139 |
+
|
140 |
+
|
141 |
+
def read_calib_file(path):
|
142 |
+
"""Read KITTI calibration file
|
143 |
+
(from https://github.com/hunse/kitti)
|
144 |
+
"""
|
145 |
+
float_chars = set("0123456789.e+- ")
|
146 |
+
data = {}
|
147 |
+
with open(path, 'r') as f:
|
148 |
+
for line in f.readlines():
|
149 |
+
key, value = line.split(':', 1)
|
150 |
+
value = value.strip()
|
151 |
+
data[key] = value
|
152 |
+
if float_chars.issuperset(value):
|
153 |
+
# try to cast to float array
|
154 |
+
try:
|
155 |
+
data[key] = np.array(list(map(float, value.split(' '))))
|
156 |
+
except ValueError:
|
157 |
+
# casting error: data[key] already eq. value, so pass
|
158 |
+
pass
|
159 |
+
|
160 |
+
return data
|
161 |
+
|
162 |
+
|
163 |
+
def sub2ind(matrixSize, rowSub, colSub):
|
164 |
+
"""Convert row, col matrix subscripts to linear indices
|
165 |
+
"""
|
166 |
+
m, n = matrixSize
|
167 |
+
return rowSub * (n-1) + colSub - 1
|
168 |
+
|
169 |
+
|
170 |
+
def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False):
|
171 |
+
"""Generate a depth map from velodyne data
|
172 |
+
"""
|
173 |
+
# load calibration files
|
174 |
+
cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt'))
|
175 |
+
velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt'))
|
176 |
+
velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis]))
|
177 |
+
velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0])))
|
178 |
+
|
179 |
+
# get image shape
|
180 |
+
im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32)
|
181 |
+
|
182 |
+
# compute projection matrix velodyne->image plane
|
183 |
+
R_cam2rect = np.eye(4)
|
184 |
+
R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3)
|
185 |
+
P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4)
|
186 |
+
P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam)
|
187 |
+
|
188 |
+
# load velodyne points and remove all behind image plane (approximation)
|
189 |
+
# each row of the velodyne data is forward, left, up, reflectance
|
190 |
+
velo = load_velodyne_points(velo_filename)
|
191 |
+
velo = velo[velo[:, 0] >= 0, :]
|
192 |
+
|
193 |
+
# project the points to the camera
|
194 |
+
velo_pts_im = np.dot(P_velo2im, velo.T).T
|
195 |
+
velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis]
|
196 |
+
|
197 |
+
if vel_depth:
|
198 |
+
velo_pts_im[:, 2] = velo[:, 0]
|
199 |
+
|
200 |
+
# check if in bounds
|
201 |
+
# use minus 1 to get the exact same value as KITTI matlab code
|
202 |
+
velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
|
203 |
+
velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
|
204 |
+
val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
|
205 |
+
val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0])
|
206 |
+
velo_pts_im = velo_pts_im[val_inds, :]
|
207 |
+
|
208 |
+
# project to image
|
209 |
+
depth = np.zeros((im_shape[:2]))
|
210 |
+
depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2]
|
211 |
+
|
212 |
+
# find the duplicate points and choose the closest depth
|
213 |
+
inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0])
|
214 |
+
dupe_inds = [item for item, count in Counter(inds).items() if count > 1]
|
215 |
+
for dd in dupe_inds:
|
216 |
+
pts = np.where(inds == dd)[0]
|
217 |
+
x_loc = int(velo_pts_im[pts[0], 0])
|
218 |
+
y_loc = int(velo_pts_im[pts[0], 1])
|
219 |
+
depth[y_loc, x_loc] = velo_pts_im[pts, 2].min()
|
220 |
+
depth[depth < 0] = 0
|
221 |
+
|
222 |
+
return depth
|
223 |
+
|
224 |
+
def pil_loader(path):
|
225 |
+
# open path as file to avoid ResourceWarning
|
226 |
+
# (https://github.com/python-pillow/Pillow/issues/835)
|
227 |
+
with open(path, 'rb') as f:
|
228 |
+
with Image.open(f) as img:
|
229 |
+
return img.convert('RGB')
|
230 |
+
|
231 |
+
|
232 |
+
class MonoDataset(data.Dataset):
|
233 |
+
"""Superclass for monocular dataloaders
|
234 |
+
|
235 |
+
Args:
|
236 |
+
data_path
|
237 |
+
filenames
|
238 |
+
height
|
239 |
+
width
|
240 |
+
frame_idxs
|
241 |
+
num_scales
|
242 |
+
is_train
|
243 |
+
img_ext
|
244 |
+
"""
|
    def __init__(self,
                 data_path,
                 filenames,
                 height,
                 width,
                 frame_idxs,
                 num_scales,
                 is_train=False,
                 img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        self.interp = Image.ANTIALIAS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # We need to specify augmentations differently in newer versions of torchvision.
        # We first try the newer tuple version; if this fails we fall back to scalars
        try:
            self.brightness = (0.8, 1.2)
            self.contrast = (0.8, 1.2)
            self.saturation = (0.8, 1.2)
            self.hue = (-0.1, 0.1)
            transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        except TypeError:
            self.brightness = 0.2
            self.contrast = 0.2
            self.saturation = 0.2
            self.hue = 0.1

        self.resize = {}
        for i in range(self.num_scales):
            s = 2 ** i
            self.resize[i] = transforms.Resize((self.height // s, self.width // s),
                                               interpolation=self.interp)

        self.load_depth = self.check_depth()

    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        for k in list(inputs):
            frame = inputs[k]
            if "color" in k:
                n, im, i = k
                for i in range(self.num_scales):
                    inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)])

        for k in list(inputs):
            f = inputs[k]
            if "color" in k:
                n, im, i = k
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to torch tensors.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)       for raw colour images,
            ("color_aug", <frame_id>, <scale>)   for augmented colour images,
            ("K", scale) or ("inv_K", scale)     for camera intrinsics,
            "stereo_T"                           for camera extrinsics, and
            "depth_gt"                           for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the fullsize image:
            -1      images at native resolution as loaded from disk
            0       images resized to (self.width,      self.height     )
            1       images resized to (self.width // 2, self.height // 2)
            2       images resized to (self.width // 4, self.height // 4)
            3       images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
        else:
            frame_index = 0

        if len(line) == 3:
            side = line[2]
        else:
            side = None

        for i in self.frame_idxs:
            if i == "s":
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = torch.from_numpy(K)
            inputs[("inv_K", scale)] = torch.from_numpy(inv_K)

        if do_color_aug:
            color_aug = transforms.ColorJitter.get_params(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = torch.from_numpy(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        raise NotImplementedError

    def check_depth(self):
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        raise NotImplementedError


class KITTIDataset(MonoDataset):
    """Superclass for different types of KITTI dataset loaders
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDataset, self).__init__(*args, **kwargs)

        # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size.
        # To normalize you need to scale the first row by 1 / image_width and the second row
        # by 1 / image_height. Monodepth2 assumes a principal point to be exactly centered.
        # If your principal point is far from the center you might need to disable the horizontal
        # flip augmentation.
        self.K = np.array([[0.58, 0, 0.5, 0],
                           [0, 1.92, 0.5, 0],
                           [0, 0, 1, 0],
                           [0, 0, 0, 1]], dtype=np.float32)

        self.full_res_shape = (1242, 375)
        self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3}

    def check_depth(self):
        line = self.filenames[0].split()
        scene_name = line[0]
        frame_index = int(line[1])

        velo_filename = os.path.join(
            self.data_path,
            scene_name,
            "velodyne_points/data/{:010d}.bin".format(int(frame_index)))

        return os.path.isfile(velo_filename)

    def get_color(self, folder, frame_index, side, do_flip):
        color = self.loader(self.get_image_path(folder, frame_index, side))

        if do_flip:
            color = color.transpose(Image.FLIP_LEFT_RIGHT)

        return color


class KITTIDepthDataset(KITTIDataset):
    """KITTI dataset which uses the updated ground truth depth maps
    """
    def __init__(self, *args, **kwargs):
        super(KITTIDepthDataset, self).__init__(*args, **kwargs)

    def get_image_path(self, folder, frame_index, side):
        f_str = "{:010d}{}".format(frame_index, self.img_ext)
        image_path = os.path.join(
            self.data_path,
            folder,
            "image_0{}/data".format(self.side_map[side]),
            f_str)
        return image_path

    def get_depth(self, folder, frame_index, side, do_flip):
        f_str = "{:010d}.png".format(frame_index)
        depth_path = os.path.join(
            self.data_path,
            folder,
            "proj_depth/groundtruth/image_0{}".format(self.side_map[side]),
            f_str)

        depth_gt = Image.open(depth_path)
        depth_gt = depth_gt.resize(self.full_res_shape, Image.NEAREST)
        depth_gt = np.array(depth_gt).astype(np.float32) / 256

        if do_flip:
            depth_gt = np.fliplr(depth_gt)

        return depth_gt
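
A minimal usage sketch for the KITTI loaders above, assuming a local KITTI download and a monodepth2-style split file; the paths "kitti_data" and "train_files.txt" are illustrative placeholders, not files shipped in this Space:

    # Hypothetical example; requires KITTI data on disk.
    from torch.utils.data import DataLoader

    filenames = open("train_files.txt").read().splitlines()  # e.g. "2011_09_26/..._sync 0000000000 l"
    dataset = KITTIDepthDataset("kitti_data", filenames,
                                height=192, width=640,
                                frame_idxs=[0, -1, 1], num_scales=4,
                                is_train=True, img_ext=".jpg")
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    batch = next(iter(loader))
    print(batch[("color", 0, 0)].shape)  # torch.Size([4, 3, 192, 640])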
src/flagged/Alpha/0.png
ADDED
src/flagged/Beta/0.png
ADDED
src/flagged/Low-res/0.png
ADDED
src/flagged/Orignal/0.png
ADDED
src/flagged/Super-res/0.png
ADDED
src/flagged/Uncertainty/0.png
ADDED
src/flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
'Orignal','Low-res','Super-res','Alpha','Beta','Uncertainty','flag','username','timestamp'
'Orignal/0.png','Low-res/0.png','Super-res/0.png','Alpha/0.png','Beta/0.png','Uncertainty/0.png','','','2022-07-09 14:01:12.964411'
src/losses.py
ADDED
@@ -0,0 +1,131 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor

class ContentLoss(nn.Module):
    """Constructs a content loss function based on the VGG19 network.
    Using high-level feature mapping layers from the latter layers will focus more on the texture content of the image.

    Paper reference list:
        -`Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network <https://arxiv.org/pdf/1609.04802.pdf>` paper.
        -`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks <https://arxiv.org/pdf/1809.00219.pdf>` paper.
        -`Perceptual Extreme Super Resolution Network with Receptive Field Block <https://arxiv.org/pdf/2005.12597.pdf>` paper.

    """

    def __init__(self) -> None:
        super(ContentLoss, self).__init__()
        # Load the VGG19 model trained on the ImageNet dataset.
        vgg19 = models.vgg19(pretrained=True).eval()
        # Extract the thirty-sixth layer output in the VGG19 model as the content loss.
        self.feature_extractor = nn.Sequential(*list(vgg19.features.children())[:36])
        # Freeze model parameters.
        for parameters in self.feature_extractor.parameters():
            parameters.requires_grad = False

        # The preprocessing method of the input data. This is the VGG model preprocessing method of the ImageNet dataset.
        self.register_buffer("mean", torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, sr: Tensor, hr: Tensor) -> Tensor:
        # Standardized operations
        sr = sr.sub(self.mean).div(self.std)
        hr = hr.sub(self.mean).div(self.std)

        # Find the feature map difference between the two images
        loss = F.l1_loss(self.feature_extractor(sr), self.feature_extractor(hr))

        return loss


class GenGaussLoss(nn.Module):
    def __init__(
        self, reduction='mean',
        alpha_eps = 1e-4, beta_eps=1e-4,
        resi_min = 1e-4, resi_max=1e3
    ) -> None:
        super(GenGaussLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor
    ):
        one_over_alpha1 = one_over_alpha + self.alpha_eps
        beta1 = beta + self.beta_eps

        resi = torch.abs(mean - target)
        # resi = torch.pow(resi*one_over_alpha1, beta1).clamp(min=self.resi_min, max=self.resi_max)
        resi = (resi*one_over_alpha1*beta1).clamp(min=self.resi_min, max=self.resi_max)
        ## check if resi has nans
        if torch.sum(resi != resi) > 0:
            print('resi has nans!!')
            return None

        log_one_over_alpha = torch.log(one_over_alpha1)
        log_beta = torch.log(beta1)
        lgamma_beta = torch.lgamma(torch.pow(beta1, -1))

        if torch.sum(log_one_over_alpha != log_one_over_alpha) > 0:
            print('log_one_over_alpha has nan')
        if torch.sum(lgamma_beta != lgamma_beta) > 0:
            print('lgamma_beta has nan')
        if torch.sum(log_beta != log_beta) > 0:
            print('log_beta has nan')

        l = resi - log_one_over_alpha + lgamma_beta - log_beta

        if self.reduction == 'mean':
            return l.mean()
        elif self.reduction == 'sum':
            return l.sum()
        else:
            print('Reduction not supported')
            return None

class TempCombLoss(nn.Module):
    def __init__(
        self, reduction='mean',
        alpha_eps = 1e-4, beta_eps=1e-4,
        resi_min = 1e-4, resi_max=1e3
    ) -> None:
        super(TempCombLoss, self).__init__()
        self.reduction = reduction
        self.alpha_eps = alpha_eps
        self.beta_eps = beta_eps
        self.resi_min = resi_min
        self.resi_max = resi_max

        self.L_GenGauss = GenGaussLoss(
            reduction=self.reduction,
            alpha_eps=self.alpha_eps, beta_eps=self.beta_eps,
            resi_min=self.resi_min, resi_max=self.resi_max
        )
        self.L_l1 = nn.L1Loss(reduction=self.reduction)

    def forward(
        self,
        mean: Tensor, one_over_alpha: Tensor, beta: Tensor, target: Tensor,
        T1: float, T2: float
    ):
        l1 = self.L_l1(mean, target)
        l2 = self.L_GenGauss(mean, one_over_alpha, beta, target)
        l = T1*l1 + T2*l2

        return l


# x1 = torch.randn(4,3,32,32)
# x2 = torch.rand(4,3,32,32)
# x3 = torch.rand(4,3,32,32)
# x4 = torch.randn(4,3,32,32)

# L = GenGaussLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# L2 = TempCombLoss(alpha_eps=1e-4, beta_eps=1e-4, resi_min=1e-4, resi_max=1e3)
# print(L(x1, x2, x3, x4), L2(x1, x2, x3, x4, 1e0, 1e-2))
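
For reference, GenGaussLoss above is a per-pixel negative log-likelihood of a generalized Gaussian with mean `mean`, scale parameterized through `one_over_alpha`, and shape `beta`; TempCombLoss mixes it with an L1 term weighted by T1 and T2. A small self-contained check, mirroring the commented-out example at the end of the file (tensor shapes here are illustrative only):

    # Hypothetical example; runs anywhere losses.py is importable.
    import torch
    from losses import GenGaussLoss, TempCombLoss

    mu, target = torch.randn(4, 3, 32, 32), torch.randn(4, 3, 32, 32)
    one_over_alpha, beta = torch.rand(4, 1, 32, 32), torch.rand(4, 1, 32, 32)

    L_gg = GenGaussLoss()
    L_comb = TempCombLoss()
    print(L_gg(mu, one_over_alpha, beta, target))                # scalar NLL
    print(L_comb(mu, one_over_alpha, beta, target, 1.0, 1e-2))   # T1*L1 + T2*NLL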
src/networks_SRGAN.py
ADDED
@@ -0,0 +1,347 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch import Tensor

# __all__ = [
#     "ResidualConvBlock",
#     "Discriminator", "Generator",
# ]


class ResidualConvBlock(nn.Module):
    """Implements residual conv function.

    Args:
        channels (int): Number of channels in the input image.
    """

    def __init__(self, channels: int) -> None:
        super(ResidualConvBlock, self).__init__()
        self.rcb = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(channels),
        )

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.rcb(x)
        out = torch.add(out, identity)

        return out


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # input size. (3) x 96 x 96
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.LeakyReLU(0.2, True),
            # state size. (64) x 48 x 48
            nn.Conv2d(64, 64, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # state size. (128) x 24 x 24
            nn.Conv2d(128, 128, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # state size. (256) x 12 x 12
            nn.Conv2d(256, 256, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 6 x 6
            nn.Conv2d(512, 512, (3, 3), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, True),
            nn.Linear(1024, 1),
        )

    def forward(self, x: Tensor) -> Tensor:
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class Generator(nn.Module):
    def __init__(self) -> None:
        super(Generator, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(3, 64, (9, 9), (1, 1), (4, 4)),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(64),
        )

        # Upscale conv block.
        self.upsampling = nn.Sequential(
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
            nn.Conv2d(64, 256, (3, 3), (1, 1), (1, 1)),
            nn.PixelShuffle(2),
            nn.PReLU(),
        )

        # Output layer.
        self.conv_block3 = nn.Conv2d(64, 3, (9, 9), (1, 1), (4, 4))

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor, dop=None) -> Tensor:
        if not dop:
            return self._forward_impl(x)
        else:
            return self._forward_w_dop_impl(x, dop)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _forward_w_dop_impl(self, x: Tensor, dop) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = F.dropout2d(self.conv_block2(out), p=dop)
        out = torch.add(out1, out2)
        out = self.upsampling(out)
        out = self.conv_block3(out)

        return out

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


#### BayesCap
class BayesCap(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        self.conv_block3_mu = nn.Conv2d(
            64, out_channels=out_channels,
            kernel_size=9, stride=1, padding=4
        )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_mu, out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)


class BayesCap_noID(nn.Module):
    def __init__(self, in_channels=3, out_channels=3) -> None:
        super(BayesCap_noID, self).__init__()
        # First conv layer.
        self.conv_block1 = nn.Sequential(
            nn.Conv2d(
                in_channels, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
        )

        # Features trunk blocks.
        trunk = []
        for _ in range(16):
            trunk.append(ResidualConvBlock(64))
        self.trunk = nn.Sequential(*trunk)

        # Second conv layer.
        self.conv_block2 = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=3, stride=1, padding=1, bias=False
            ),
            nn.BatchNorm2d(64),
        )

        # Output layer.
        # self.conv_block3_mu = nn.Conv2d(
        #     64, out_channels=out_channels,
        #     kernel_size=9, stride=1, padding=4
        # )
        self.conv_block3_alpha = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )
        self.conv_block3_beta = nn.Sequential(
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 64,
                kernel_size=9, stride=1, padding=4
            ),
            nn.PReLU(),
            nn.Conv2d(
                64, 1,
                kernel_size=9, stride=1, padding=4
            ),
            nn.ReLU(),
        )

        # Initialize neural network weights.
        self._initialize_weights()

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: Tensor) -> Tensor:
        out1 = self.conv_block1(x)
        out = self.trunk(out1)
        out2 = self.conv_block2(out)
        out = out1 + out2
        # out_mu = self.conv_block3_mu(out)
        out_alpha = self.conv_block3_alpha(out)
        out_beta = self.conv_block3_beta(out)
        return out_alpha, out_beta

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
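
A sketch of how the BayesCap head is typically wired around the frozen SR generator above at inference time; the checkpoint loading is omitted and the uncertainty formula assumes the two heads parameterize a generalized Gaussian through 1/alpha and beta, as in losses.py:

    # Hypothetical example; weights would normally be loaded from the Space's checkpoints.
    import torch
    from networks_SRGAN import Generator, BayesCap

    device = "cuda" if torch.cuda.is_available() else "cpu"
    net_g = Generator().to(device).eval()              # frozen base SR model
    net_cap = BayesCap(in_channels=3, out_channels=3).to(device).eval()

    with torch.no_grad():
        lr = torch.rand(1, 3, 64, 64, device=device)
        sr = net_g(lr)                                 # super-resolved image
        mu, one_over_alpha, beta = net_cap(sr)         # mean + generalized-Gaussian params
        # Standard generalized-Gaussian variance, assuming scale = 1/one_over_alpha:
        alpha = 1.0 / (one_over_alpha + 1e-4)
        var = alpha ** 2 * torch.exp(torch.lgamma(3.0 / (beta + 1e-4)) - torch.lgamma(1.0 / (beta + 1e-4)))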
src/networks_T1toT2.py
ADDED
@@ -0,0 +1,477 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

### components
class ResConv(nn.Module):
    """
    Residual convolutional block, where
    convolutional block consists: (convolution => [BN] => ReLU) * 3
    residual connection adds the input to the output
    """
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )
        self.double_conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
    def forward(self, x):
        x_in = self.double_conv1(x)
        x1 = self.double_conv(x)
        return self.double_conv(x) + x_in

class Down(nn.Module):
    """Downscaling with maxpool then Resconv"""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            ResConv(in_channels, out_channels)
        )
    def forward(self, x):
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Upscaling then double conv"""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = ResConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels , in_channels // 2, kernel_size=2, stride=2)
            self.conv = ResConv(in_channels, out_channels)
    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(
            x1,
            [
                diffX // 2, diffX - diffX // 2,
                diffY // 2, diffY - diffY // 2
            ]
        )
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        # return F.relu(self.conv(x))
        return self.conv(x)

##### The composite networks
class UNet(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, out_channels)
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y = self.outc(x)
        return y

class CasUNet(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x, dop=None):
        y = x
        for i in range(self.n_unet):
            if i==0:
                if dop is not None:
                    y = F.dropout2d(self.unet_list[i](y), p=dop)
                else:
                    y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        return y

class CasUNet_2head(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_2head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                self.unet_list.append(UNet_2head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        y_mean, y_sigma = y[0], y[1]
        return y_mean, y_sigma

class CasUNet_3head(nn.Module):
    def __init__(self, n_unet, io_channels, bilinear=True):
        super(CasUNet_3head, self).__init__()
        self.n_unet = n_unet
        self.io_channels = io_channels
        self.bilinear = bilinear
        ####
        self.unet_list = nn.ModuleList()
        for i in range(self.n_unet):
            if i != self.n_unet-1:
                self.unet_list.append(UNet(self.io_channels, self.io_channels, self.bilinear))
            else:
                self.unet_list.append(UNet_3head(self.io_channels, self.io_channels, self.bilinear))
    def forward(self, x):
        y = x
        for i in range(self.n_unet):
            if i==0:
                y = self.unet_list[i](y)
            else:
                y = self.unet_list[i](y+x)
        y_mean, y_alpha, y_beta = y[0], y[1], y[2]
        return y_mean, y_alpha, y_beta

class UNet_2head(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_2head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_var = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y_mean, y_var = self.out_mean(x), self.out_var(x)
        return y_mean, y_var

class UNet_3head(nn.Module):
    def __init__(self, n_channels, out_channels, bilinear=True):
        super(UNet_3head, self).__init__()
        self.n_channels = n_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        ####
        self.inc = ResConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        #per pixel multiple channels may exist
        self.out_mean = OutConv(64, out_channels)
        #variance will always be a single number for a pixel
        self.out_alpha = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
        self.out_beta = nn.Sequential(
            OutConv(64, 128),
            OutConv(128, 1),
            nn.ReLU()
        )
    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        y_mean, y_alpha, y_beta = self.out_mean(x), \
            self.out_alpha(x), self.out_beta(x)
        return y_mean, y_alpha, y_beta

class ResidualBlock(nn.Module):
    def __init__(self, in_features):
        super(ResidualBlock, self).__init__()
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_features, in_features, 3),
            nn.InstanceNorm2d(in_features)
        ]
        self.conv_block = nn.Sequential(*conv_block)
    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
        # Initial convolution block
        model = [
            nn.ReflectionPad2d(3), nn.Conv2d(input_nc, 64, 7),
            nn.InstanceNorm2d(64), nn.ReLU(inplace=True)
        ]
        # Downsampling
        in_features = 64
        out_features = in_features*2
        for _ in range(2):
            model += [
                nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features*2
        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]
        # Upsampling
        out_features = in_features//2
        for _ in range(2):
            model += [
                nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_features),
                nn.ReLU(inplace=True)
            ]
            in_features = out_features
            out_features = in_features//2
        # Output layer
        model += [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7), nn.Tanh()]
        self.model = nn.Sequential(*model)
    def forward(self, x):
        return self.model(x)


class ResnetGenerator(nn.Module):
    """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
    We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
    """

    def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'):
        """Construct a Resnet-based generator
        Parameters:
            input_nc (int)      -- the number of channels in input images
            output_nc (int)     -- the number of channels in output images
            ngf (int)           -- the number of filters in the last conv layer
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers
            n_blocks (int)      -- the number of ResNet blocks
            padding_type (str)  -- the name of padding layer in conv layers: reflect | replicate | zero
        """
        assert(n_blocks >= 0)
        super(ResnetGenerator, self).__init__()
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

        n_downsampling = 2
        for i in range(n_downsampling):  # add downsampling layers
            mult = 2 ** i
            model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                      norm_layer(ngf * mult * 2),
                      nn.ReLU(True)]

        mult = 2 ** n_downsampling
        for i in range(n_blocks):  # add ResNet blocks

            model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]

        for i in range(n_downsampling):  # add upsampling layers
            mult = 2 ** (n_downsampling - i)
            model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                         kernel_size=3, stride=2,
                                         padding=1, output_padding=1,
                                         bias=use_bias),
                      norm_layer(int(ngf * mult / 2)),
                      nn.ReLU(True)]
        model += [nn.ReflectionPad2d(3)]
        model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
        model += [nn.Tanh()]

        self.model = nn.Sequential(*model)

    def forward(self, input):
        """Standard forward"""
        return self.model(input)


class ResnetBlock(nn.Module):
    """Define a Resnet block"""

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Initialize the Resnet block
        A resnet block is a conv block with skip connections
        We construct a conv block with build_conv_block function,
        and implement skip connections in <forward> function.
        Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
        """
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)

    def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        """Construct a convolutional block.
        Parameters:
            dim (int)           -- the number of channels in the conv layer.
            padding_type (str)  -- the name of padding layer: reflect | replicate | zero
            norm_layer          -- normalization layer
            use_dropout (bool)  -- if use dropout layers.
            use_bias (bool)     -- if the conv layer uses bias or not
        Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU))
        """
        conv_block = []
        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == 'reflect':
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            p = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        """Forward function (with skip connections)"""
        out = x + self.conv_block(x)  # add skip connections
        return out

### discriminator
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator"""
    def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        kw = 4
        padw = 1
        sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
        self.model = nn.Sequential(*sequence)
    def forward(self, input):
        """Standard forward."""
        return self.model(input)
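
A minimal sketch of how the cascaded 3-head U-Net above could be instantiated for a single-channel T1-to-T2 translation task; the input size and single-U-Net cascade are illustrative choices, not the Space's trained configuration:

    # Hypothetical example with random weights and a random "T1 slice".
    import torch
    from networks_T1toT2 import CasUNet_3head

    netG = CasUNet_3head(n_unet=1, io_channels=1).eval()
    with torch.no_grad():
        t1 = torch.rand(1, 1, 128, 128)        # illustrative T1 slice
        t2_mean, t2_alpha, t2_beta = netG(t1)  # translated mean + uncertainty heads
    print(t2_mean.shape, t2_alpha.shape, t2_beta.shape)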
src/utils.py
ADDED
@@ -0,0 +1,1273 @@
1 |
+
import random
|
2 |
+
from typing import Any, Optional
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
from glob import glob
|
7 |
+
from PIL import Image, ImageDraw
|
8 |
+
from tqdm import tqdm
|
9 |
+
import kornia
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import seaborn as sns
|
12 |
+
import albumentations as albu
|
13 |
+
import functools
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from torch import Tensor
|
19 |
+
import torchvision as tv
|
20 |
+
import torchvision.models as models
|
21 |
+
from torchvision import transforms
|
22 |
+
from torchvision.transforms import functional as F
|
23 |
+
from losses import TempCombLoss
|
24 |
+
|
25 |
+
########### DeblurGAN function
|
26 |
+
def get_norm_layer(norm_type='instance'):
|
27 |
+
if norm_type == 'batch':
|
28 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
29 |
+
elif norm_type == 'instance':
|
30 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
|
31 |
+
else:
|
32 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
33 |
+
return norm_layer
|
34 |
+
|
35 |
+
def _array_to_batch(x):
|
36 |
+
x = np.transpose(x, (2, 0, 1))
|
37 |
+
x = np.expand_dims(x, 0)
|
38 |
+
return torch.from_numpy(x)
|
39 |
+
|
40 |
+
def get_normalize():
|
41 |
+
normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
42 |
+
normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
|
43 |
+
|
44 |
+
def process(a, b):
|
45 |
+
r = normalize(image=a, target=b)
|
46 |
+
return r['image'], r['target']
|
47 |
+
|
48 |
+
return process
|
49 |
+
|
50 |
+
def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
|
51 |
+
x, _ = get_normalize()(x, x)
|
52 |
+
if mask is None:
|
53 |
+
mask = np.ones_like(x, dtype=np.float32)
|
54 |
+
else:
|
55 |
+
mask = np.round(mask.astype('float32') / 255)
|
56 |
+
|
57 |
+
h, w, _ = x.shape
|
58 |
+
block_size = 32
|
59 |
+
min_height = (h // block_size + 1) * block_size
|
60 |
+
min_width = (w // block_size + 1) * block_size
|
61 |
+
|
62 |
+
pad_params = {'mode': 'constant',
|
63 |
+
'constant_values': 0,
|
64 |
+
'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
|
65 |
+
}
|
66 |
+
x = np.pad(x, **pad_params)
|
67 |
+
mask = np.pad(mask, **pad_params)
|
68 |
+
|
69 |
+
return map(_array_to_batch, (x, mask)), h, w
|
70 |
+
|
71 |
+
def postprocess(x: torch.Tensor) -> np.ndarray:
|
72 |
+
x, = x
|
73 |
+
x = x.detach().cpu().float().numpy()
|
74 |
+
x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
|
75 |
+
return x.astype('uint8')
|
76 |
+
|
77 |
+
def sorted_glob(pattern):
|
78 |
+
return sorted(glob(pattern))
|
79 |
+
###########
|
80 |
+
|
81 |
+
def normalize(image: np.ndarray) -> np.ndarray:
|
82 |
+
"""Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
83 |
+
Args:
|
84 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
85 |
+
Returns:
|
86 |
+
Normalized image data. Data range [0, 1].
|
87 |
+
"""
|
88 |
+
return image.astype(np.float64) / 255.0
|
89 |
+
|
90 |
+
|
91 |
+
def unnormalize(image: np.ndarray) -> np.ndarray:
|
92 |
+
"""Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
93 |
+
Args:
|
94 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
95 |
+
Returns:
|
96 |
+
Denormalized image data. Data range [0, 255].
|
97 |
+
"""
|
98 |
+
return image.astype(np.float64) * 255.0
|
99 |
+
|
100 |
+
|
101 |
+
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
|
102 |
+
"""Convert ``PIL.Image`` to Tensor.
|
103 |
+
Args:
|
104 |
+
image (np.ndarray): The image data read by ``PIL.Image``
|
105 |
+
range_norm (bool): Scale [0, 1] data to between [-1, 1]
|
106 |
+
half (bool): Whether to convert torch.float32 similarly to torch.half type.
|
107 |
+
Returns:
|
108 |
+
Normalized image data
|
109 |
+
Examples:
|
110 |
+
>>> image = Image.open("image.bmp")
|
111 |
+
>>> tensor_image = image2tensor(image, range_norm=False, half=False)
|
112 |
+
"""
|
113 |
+
tensor = F.to_tensor(image)
|
114 |
+
|
115 |
+
if range_norm:
|
116 |
+
tensor = tensor.mul_(2.0).sub_(1.0)
|
117 |
+
if half:
|
118 |
+
tensor = tensor.half()
|
119 |
+
|
120 |
+
return tensor
|
121 |
+
|
122 |
+
|
123 |
+
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
|
124 |
+
"""Converts ``torch.Tensor`` to ``PIL.Image``.
|
125 |
+
Args:
|
126 |
+
tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
|
127 |
+
range_norm (bool): Scale [-1, 1] data to between [0, 1]
|
128 |
+
half (bool): Whether to convert torch.float32 similarly to torch.half type.
|
129 |
+
Returns:
|
130 |
+
Convert image data to support PIL library
|
131 |
+
Examples:
|
132 |
+
>>> tensor = torch.randn([1, 3, 128, 128])
|
133 |
+
>>> image = tensor2image(tensor, range_norm=False, half=False)
|
134 |
+
"""
|
135 |
+
if range_norm:
|
136 |
+
tensor = tensor.add_(1.0).div_(2.0)
|
137 |
+
if half:
|
138 |
+
tensor = tensor.half()
|
139 |
+
|
140 |
+
image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
|
141 |
+
|
142 |
+
return image
|
143 |
+
|
144 |
+
|
145 |
+
def convert_rgb_to_y(image: Any) -> Any:
|
146 |
+
"""Convert RGB image or tensor image data to YCbCr(Y) format.
|
147 |
+
Args:
|
148 |
+
image: RGB image data read by ``PIL.Image''.
|
149 |
+
Returns:
|
150 |
+
Y image array data.
|
151 |
+
"""
|
152 |
+
if type(image) == np.ndarray:
|
153 |
+
return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
|
154 |
+
elif type(image) == torch.Tensor:
|
155 |
+
if len(image.shape) == 4:
|
156 |
+
image = image.squeeze_(0)
|
157 |
+
return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
|
158 |
+
else:
|
159 |
+
raise Exception("Unknown Type", type(image))
|
160 |
+
|
161 |
+
|
162 |
+
def convert_rgb_to_ycbcr(image: Any) -> Any:
|
163 |
+
"""Convert RGB image or tensor image data to YCbCr format.
|
164 |
+
Args:
|
165 |
+
image: RGB image data read by ``PIL.Image''.
|
166 |
+
Returns:
|
167 |
+
YCbCr image array data.
|
168 |
+
"""
|
169 |
+
if type(image) == np.ndarray:
|
170 |
+
y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
|
171 |
+
cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
|
172 |
+
cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
|
173 |
+
return np.array([y, cb, cr]).transpose([1, 2, 0])
|
174 |
+
elif type(image) == torch.Tensor:
|
175 |
+
if len(image.shape) == 4:
|
176 |
+
image = image.squeeze(0)
|
177 |
+
y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
|
178 |
+
cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
|
179 |
+
cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
|
180 |
+
return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
|
181 |
+
else:
|
182 |
+
raise Exception("Unknown Type", type(image))
|
183 |
+
|
184 |
+
|
185 |
+
def convert_ycbcr_to_rgb(image: Any) -> Any:
|
186 |
+
"""Convert YCbCr format image to RGB format.
|
187 |
+
Args:
|
188 |
+
image: YCbCr image data read by ``PIL.Image``.
|
189 |
+
Returns:
|
190 |
+
RGB image array data.
|
191 |
+
"""
|
192 |
+
if type(image) == np.ndarray:
|
193 |
+
r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
|
194 |
+
g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
|
195 |
+
b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
|
196 |
+
return np.array([r, g, b]).transpose([1, 2, 0])
|
197 |
+
elif type(image) == torch.Tensor:
|
198 |
+
if len(image.shape) == 4:
|
199 |
+
image = image.squeeze(0)
|
200 |
+
r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
|
201 |
+
g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
|
202 |
+
b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
|
203 |
+
return torch.cat([r, g, b], 0).permute(1, 2, 0)
|
204 |
+
else:
|
205 |
+
raise Exception("Unknown Type", type(image))
|
206 |
+
|
207 |
+
|
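# A minimal sketch of the colour-space helpers above, assuming this module is
# importable as `utils` and pixel values are in [0, 255]; the Y channel is the
# luminance often used for Y-channel PSNR/SSIM evaluation.
import numpy as np
from utils import convert_rgb_to_y, convert_rgb_to_ycbcr

demo_rgb = np.random.randint(0, 256, size=(32, 32, 3)).astype(np.float64)
demo_y = convert_rgb_to_y(demo_rgb)          # 32x32 luminance map, roughly in [16, 235]
demo_ycbcr = convert_rgb_to_ycbcr(demo_rgb)  # 32x32x3 array with Y, Cb, Cr channels
print(demo_y.shape, demo_ycbcr.shape)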
208 |
+
def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
|
209 |
+
"""Cut ``PIL.Image`` in the center area of the image.
|
210 |
+
Args:
|
211 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
212 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
213 |
+
image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
|
214 |
+
upscale_factor (int): magnification factor.
|
215 |
+
Returns:
|
216 |
+
Center-cropped low-resolution and high-resolution images.
|
217 |
+
"""
|
218 |
+
w, h = hr.size
|
219 |
+
|
220 |
+
left = (w - image_size) // 2
|
221 |
+
top = (h - image_size) // 2
|
222 |
+
right = left + image_size
|
223 |
+
bottom = top + image_size
|
224 |
+
|
225 |
+
lr = lr.crop((left // upscale_factor,
|
226 |
+
top // upscale_factor,
|
227 |
+
right // upscale_factor,
|
228 |
+
bottom // upscale_factor))
|
229 |
+
hr = hr.crop((left, top, right, bottom))
|
230 |
+
|
231 |
+
return lr, hr
|
232 |
+
|
233 |
+
|
234 |
+
def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
|
235 |
+
"""Will ``PIL.Image`` randomly capture the specified area of the image.
|
236 |
+
Args:
|
237 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
238 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
239 |
+
image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
|
240 |
+
upscale_factor (int): magnification factor.
|
241 |
+
Returns:
|
242 |
+
Randomly cropped low-resolution images and high-resolution images.
|
243 |
+
"""
|
244 |
+
w, h = hr.size
|
245 |
+
left = torch.randint(0, w - image_size + 1, size=(1,)).item()
|
246 |
+
top = torch.randint(0, h - image_size + 1, size=(1,)).item()
|
247 |
+
right = left + image_size
|
248 |
+
bottom = top + image_size
|
249 |
+
|
250 |
+
lr = lr.crop((left // upscale_factor,
|
251 |
+
top // upscale_factor,
|
252 |
+
right // upscale_factor,
|
253 |
+
bottom // upscale_factor))
|
254 |
+
hr = hr.crop((left, top, right, bottom))
|
255 |
+
|
256 |
+
return lr, hr
|
257 |
+
|
258 |
+
|
259 |
+
def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
|
260 |
+
"""Will ``PIL.Image`` randomly rotate the image.
|
261 |
+
Args:
|
262 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
263 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
264 |
+
angle (int): rotation angle, applied clockwise or counterclockwise at random.
|
265 |
+
Returns:
|
266 |
+
Randomly rotated low-resolution images and high-resolution images.
|
267 |
+
"""
|
268 |
+
angle = random.choice((+angle, -angle))
|
269 |
+
lr = F.rotate(lr, angle)
|
270 |
+
hr = F.rotate(hr, angle)
|
271 |
+
|
272 |
+
return lr, hr
|
273 |
+
|
274 |
+
|
275 |
+
def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
|
276 |
+
"""Flip the ``PIL.Image`` image horizontally randomly.
|
277 |
+
Args:
|
278 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
279 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
280 |
+
p (optional, float): flip probability. (Default: 0.5)
|
281 |
+
Returns:
|
282 |
+
Low-resolution image and high-resolution image after random horizontal flip.
|
283 |
+
"""
|
284 |
+
if torch.rand(1).item() > p:
|
285 |
+
lr = F.hflip(lr)
|
286 |
+
hr = F.hflip(hr)
|
287 |
+
|
288 |
+
return lr, hr
|
289 |
+
|
290 |
+
|
291 |
+
def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
|
292 |
+
"""Turn the ``PIL.Image`` image upside down randomly.
|
293 |
+
Args:
|
294 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
295 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
296 |
+
p (optional, float): flip probability. (Default: 0.5)
|
297 |
+
Returns:
|
298 |
+
Low-resolution image and high-resolution image after random vertical flip.
|
299 |
+
"""
|
300 |
+
if torch.rand(1).item() > p:
|
301 |
+
lr = F.vflip(lr)
|
302 |
+
hr = F.vflip(hr)
|
303 |
+
|
304 |
+
return lr, hr
|
305 |
+
|
306 |
+
|
307 |
+
def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
|
308 |
+
"""Set ``PIL.Image`` to randomly adjust the image brightness.
|
309 |
+
Args:
|
310 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
311 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
312 |
+
Returns:
|
313 |
+
Low-resolution image and high-resolution image with randomly adjusted brightness.
|
314 |
+
"""
|
315 |
+
# Randomly adjust the brightness gain range.
|
316 |
+
factor = random.uniform(0.5, 2)
|
317 |
+
lr = F.adjust_brightness(lr, factor)
|
318 |
+
hr = F.adjust_brightness(hr, factor)
|
319 |
+
|
320 |
+
return lr, hr
|
321 |
+
|
322 |
+
|
323 |
+
def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
|
324 |
+
"""Set ``PIL.Image`` to randomly adjust the image contrast.
|
325 |
+
Args:
|
326 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
327 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
328 |
+
Returns:
|
329 |
+
Low-resolution image and high-resolution image with randomly adjusted contrast.
|
330 |
+
"""
|
331 |
+
# Randomly adjust the contrast gain range.
|
332 |
+
factor = random.uniform(0.5, 2)
|
333 |
+
lr = F.adjust_contrast(lr, factor)
|
334 |
+
hr = F.adjust_contrast(hr, factor)
|
335 |
+
|
336 |
+
return lr, hr
|
337 |
+
|
338 |
+
#### metrics to compute -- assumes single images, i.e., tensor of 3 dims
|
339 |
+
def img_mae(x1, x2):
|
340 |
+
m = torch.abs(x1-x2).mean()
|
341 |
+
return m
|
342 |
+
|
343 |
+
def img_mse(x1, x2):
|
344 |
+
m = torch.pow(torch.abs(x1-x2),2).mean()
|
345 |
+
return m
|
346 |
+
|
347 |
+
def img_psnr(x1, x2):
|
348 |
+
m = kornia.metrics.psnr(x1, x2, 1)
|
349 |
+
return m
|
350 |
+
|
351 |
+
def img_ssim(x1, x2):
|
352 |
+
m = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
|
353 |
+
m = m.mean()
|
354 |
+
return m
|
355 |
+
|
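# A minimal sketch of the metric helpers above on synthetic data, assuming this
# module is importable as `utils`, kornia is installed, and inputs are 3xHxW
# tensors scaled to [0, 1].
import torch
from utils import img_mae, img_mse, img_psnr, img_ssim

demo_x1 = torch.rand(3, 64, 64)
demo_x2 = (demo_x1 + 0.05 * torch.randn_like(demo_x1)).clamp(0, 1)
print(img_mae(demo_x1, demo_x2).item())   # mean absolute error
print(img_mse(demo_x1, demo_x2).item())   # mean squared error
print(img_psnr(demo_x1, demo_x2).item())  # PSNR with max value 1
print(img_ssim(demo_x1, demo_x2).item())  # mean SSIM with a 5x5 window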
356 |
+
def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
|
357 |
+
'''
|
358 |
+
xLR/SR/HR: 3xHxW
|
359 |
+
xSRvar: 1xHxW
|
360 |
+
'''
|
361 |
+
plt.figure(figsize=(30,10))
|
362 |
+
|
363 |
+
plt.subplot(1,5,1)
|
364 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
365 |
+
plt.axis('off')
|
366 |
+
|
367 |
+
plt.subplot(1,5,2)
|
368 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
369 |
+
plt.axis('off')
|
370 |
+
|
371 |
+
plt.subplot(1,5,3)
|
372 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
373 |
+
plt.axis('off')
|
374 |
+
|
375 |
+
plt.subplot(1,5,4)
|
376 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
|
377 |
+
print('error', error_map.min(), error_map.max())
|
378 |
+
plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
|
379 |
+
plt.clim(elim[0], elim[1])
|
380 |
+
plt.axis('off')
|
381 |
+
|
382 |
+
plt.subplot(1,5,5)
|
383 |
+
print('uncer', xSRvar.min(), xSRvar.max())
|
384 |
+
plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
385 |
+
plt.clim(ulim[0], ulim[1])
|
386 |
+
plt.axis('off')
|
387 |
+
|
388 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
389 |
+
plt.show()
|
390 |
+
|
391 |
+
def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
|
392 |
+
'''
|
393 |
+
xLR/SR/HR: 3xHxW
|
394 |
+
'''
|
395 |
+
plt.figure(figsize=(30,10))
|
396 |
+
|
397 |
+
if task != 'm':
|
398 |
+
plt.subplot(1,4,1)
|
399 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
400 |
+
plt.axis('off')
|
401 |
+
|
402 |
+
plt.subplot(1,4,2)
|
403 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
404 |
+
plt.axis('off')
|
405 |
+
|
406 |
+
plt.subplot(1,4,3)
|
407 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
408 |
+
plt.axis('off')
|
409 |
+
else:
|
410 |
+
plt.subplot(1,4,1)
|
411 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
412 |
+
plt.clim(0,0.9)
|
413 |
+
plt.axis('off')
|
414 |
+
|
415 |
+
plt.subplot(1,4,2)
|
416 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
417 |
+
plt.clim(0,0.9)
|
418 |
+
plt.axis('off')
|
419 |
+
|
420 |
+
plt.subplot(1,4,3)
|
421 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
422 |
+
plt.clim(0,0.9)
|
423 |
+
plt.axis('off')
|
424 |
+
|
425 |
+
plt.subplot(1,4,4)
|
426 |
+
if task == 'inpainting':
|
427 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
|
428 |
+
else:
|
429 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
|
430 |
+
print('error', error_map.min(), error_map.max())
|
431 |
+
plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
|
432 |
+
plt.clim(elim[0], elim[1])
|
433 |
+
plt.axis('off')
|
434 |
+
|
435 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
436 |
+
plt.show()
|
437 |
+
|
438 |
+
def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
|
439 |
+
'''
|
440 |
+
xSRvar: 1xHxW
|
441 |
+
'''
|
442 |
+
plt.figure(figsize=(30,10))
|
443 |
+
|
444 |
+
plt.subplot(1,4,1)
|
445 |
+
print('uncer', xSRvar1.min(), xSRvar1.max())
|
446 |
+
plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
447 |
+
plt.clim(ulim[0], ulim[1])
|
448 |
+
plt.axis('off')
|
449 |
+
|
450 |
+
plt.subplot(1,4,2)
|
451 |
+
print('uncer', xSRvar2.min(), xSRvar2.max())
|
452 |
+
plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
453 |
+
plt.clim(ulim[0], ulim[1])
|
454 |
+
plt.axis('off')
|
455 |
+
|
456 |
+
plt.subplot(1,4,3)
|
457 |
+
print('uncer', xSRvar3.min(), xSRvar3.max())
|
458 |
+
plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
459 |
+
plt.clim(ulim[0], ulim[1])
|
460 |
+
plt.axis('off')
|
461 |
+
|
462 |
+
plt.subplot(1,4,4)
|
463 |
+
print('uncer', xSRvar4.min(), xSRvar4.max())
|
464 |
+
plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
465 |
+
plt.clim(ulim[0], ulim[1])
|
466 |
+
plt.axis('off')
|
467 |
+
|
468 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
469 |
+
plt.show()
|
470 |
+
|
471 |
+
def get_UCE(list_err, list_yout_var, num_bins=100):
|
472 |
+
err_min = np.min(list_err)
|
473 |
+
err_max = np.max(list_err)
|
474 |
+
err_len = (err_max-err_min)/num_bins
|
475 |
+
num_points = len(list_err)
|
476 |
+
|
477 |
+
bin_stats = {}
|
478 |
+
for i in range(num_bins):
|
479 |
+
bin_stats[i] = {
|
480 |
+
'start_idx': err_min + i*err_len,
|
481 |
+
'end_idx': err_min + (i+1)*err_len,
|
482 |
+
'num_points': 0,
|
483 |
+
'mean_err': 0,
|
484 |
+
'mean_var': 0,
|
485 |
+
}
|
486 |
+
|
487 |
+
for e,v in zip(list_err, list_yout_var):
|
488 |
+
for i in range(num_bins):
|
489 |
+
if e>=bin_stats[i]['start_idx'] and e<bin_stats[i]['end_idx']:
|
490 |
+
bin_stats[i]['num_points'] += 1
|
491 |
+
bin_stats[i]['mean_err'] += e
|
492 |
+
bin_stats[i]['mean_var'] += v
|
493 |
+
|
494 |
+
uce = 0
|
495 |
+
eps = 1e-8
|
496 |
+
for i in range(num_bins):
|
497 |
+
bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
|
498 |
+
bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
|
499 |
+
bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points']/num_points) \
|
500 |
+
*(np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
|
501 |
+
uce += bin_stats[i]['uce_bin']
|
502 |
+
|
503 |
+
list_x, list_y = [], []
|
504 |
+
for i in range(num_bins):
|
505 |
+
if bin_stats[i]['num_points']>0:
|
506 |
+
list_x.append(bin_stats[i]['mean_err'])
|
507 |
+
list_y.append(bin_stats[i]['mean_var'])
|
508 |
+
|
509 |
+
# sns.set_style('darkgrid')
|
510 |
+
# sns.scatterplot(x=list_x, y=list_y)
|
511 |
+
# sns.regplot(x=list_x, y=list_y, order=1)
|
512 |
+
# plt.xlabel('MSE', fontsize=34)
|
513 |
+
# plt.ylabel('Uncertainty', fontsize=34)
|
514 |
+
# plt.plot(list_x, list_x, color='r')
|
515 |
+
# plt.xlim(np.min(list_x), np.max(list_x))
|
516 |
+
# plt.ylim(np.min(list_err), np.max(list_x))
|
517 |
+
# plt.show()
|
518 |
+
|
519 |
+
return bin_stats, uce
|
520 |
+
|
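# A minimal sketch of how the calibration metric above can be called, assuming
# this module is importable as `utils`; the two lists stand in for per-pixel
# squared errors and predicted variances collected during evaluation.
import numpy as np
from utils import get_UCE

demo_err = list(np.random.rand(10000) * 0.01)
demo_var = list(np.random.rand(10000) * 0.01)
demo_bins, demo_uce = get_UCE(demo_err, demo_var, num_bins=100)
print('UCE:', demo_uce)  # lower UCE means better-calibrated uncertainty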
521 |
+
##################### training BayesCap
|
522 |
+
def train_BayesCap(
|
523 |
+
NetC,
|
524 |
+
NetG,
|
525 |
+
train_loader,
|
526 |
+
eval_loader,
|
527 |
+
Cri = TempCombLoss(),
|
528 |
+
device='cuda',
|
529 |
+
dtype=torch.cuda.FloatTensor,
|
530 |
+
init_lr=1e-4,
|
531 |
+
num_epochs=100,
|
532 |
+
eval_every=1,
|
533 |
+
ckpt_path='../ckpt/BayesCap',
|
534 |
+
T1=1e0,
|
535 |
+
T2=5e-2,
|
536 |
+
task=None,
|
537 |
+
):
|
538 |
+
NetC.to(device)
|
539 |
+
NetC.train()
|
540 |
+
NetG.to(device)
|
541 |
+
NetG.eval()
|
542 |
+
optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
|
543 |
+
optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
|
544 |
+
|
545 |
+
score = -1e8
|
546 |
+
all_loss = []
|
547 |
+
for eph in range(num_epochs):
|
548 |
+
eph_loss = 0
|
549 |
+
with tqdm(train_loader, unit='batch') as tepoch:
|
550 |
+
for (idx, batch) in enumerate(tepoch):
|
551 |
+
if idx>2000:
|
552 |
+
break
|
553 |
+
tepoch.set_description('Epoch {}'.format(eph))
|
554 |
+
##
|
555 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
556 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
557 |
+
if task == 'inpainting':
|
558 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
559 |
+
xMask = xMask.to(device).type(dtype)
|
560 |
+
# pass them through the network
|
561 |
+
with torch.no_grad():
|
562 |
+
if task == 'inpainting':
|
563 |
+
_, xSR1 = NetG(xLR, xMask)
|
564 |
+
elif task == 'depth':
|
565 |
+
xSR1 = NetG(xLR)[("disp", 0)]
|
566 |
+
else:
|
567 |
+
xSR1 = NetG(xLR)
|
568 |
+
# with torch.autograd.set_detect_anomaly(True):
|
569 |
+
xSR = xSR1.clone()
|
570 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
571 |
+
# print(xSRC_alpha)
|
572 |
+
optimizer.zero_grad()
|
573 |
+
if task == 'depth':
|
574 |
+
loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
|
575 |
+
else:
|
576 |
+
loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
|
577 |
+
# print(loss)
|
578 |
+
loss.backward()
|
579 |
+
optimizer.step()
|
580 |
+
##
|
581 |
+
eph_loss += loss.item()
|
582 |
+
tepoch.set_postfix(loss=loss.item())
|
583 |
+
eph_loss /= len(train_loader)
|
584 |
+
all_loss.append(eph_loss)
|
585 |
+
print('Avg. loss: {}'.format(eph_loss))
|
586 |
+
# evaluate and save the models
|
587 |
+
torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
|
588 |
+
if eph%eval_every == 0:
|
589 |
+
curr_score = eval_BayesCap(
|
590 |
+
NetC,
|
591 |
+
NetG,
|
592 |
+
eval_loader,
|
593 |
+
device=device,
|
594 |
+
dtype=dtype,
|
595 |
+
task=task,
|
596 |
+
)
|
597 |
+
print('current score: {} | Last best score: {}'.format(curr_score, score))
|
598 |
+
if curr_score >= score:
|
599 |
+
score = curr_score
|
600 |
+
torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
|
601 |
+
optim_scheduler.step()
|
602 |
+
|
603 |
+
#### get different uncertainty maps
|
604 |
+
def get_uncer_BayesCap(
|
605 |
+
NetC,
|
606 |
+
NetG,
|
607 |
+
xin,
|
608 |
+
task=None,
|
609 |
+
xMask=None,
|
610 |
+
):
|
611 |
+
with torch.no_grad():
|
612 |
+
if task == 'inpainting':
|
613 |
+
_, xSR = NetG(xin, xMask)
|
614 |
+
else:
|
615 |
+
xSR = NetG(xin)
|
616 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
617 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
618 |
+
b_map = xSRC_beta.to('cpu').data
|
619 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
620 |
+
|
621 |
+
return xSRvar
|
622 |
+
|
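# The per-pixel uncertainty above is the variance of a zero-mean generalized
# Gaussian with scale a = 1/alpha and shape beta: var = a^2 * Gamma(3/beta) / Gamma(1/beta),
# evaluated through torch.lgamma for numerical stability. A minimal numeric check
# (only torch assumed); beta = 2 recovers the Gaussian case, where var = a^2 / 2.
import torch

demo_a = torch.tensor(0.1)
demo_beta = torch.tensor(2.0)
demo_var = (demo_a ** 2) * torch.exp(torch.lgamma(3 / demo_beta)) / torch.exp(torch.lgamma(1 / demo_beta))
print(demo_var.item())  # 0.005 = 0.1**2 * Gamma(1.5) / Gamma(0.5)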
623 |
+
def get_uncer_TTDAp(
|
624 |
+
NetG,
|
625 |
+
xin,
|
626 |
+
p_mag=0.05,
|
627 |
+
num_runs=50,
|
628 |
+
task=None,
|
629 |
+
xMask=None,
|
630 |
+
):
|
631 |
+
list_xSR = []
|
632 |
+
with torch.no_grad():
|
633 |
+
for z in range(num_runs):
|
634 |
+
if task == 'inpainting':
|
635 |
+
_, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask)
|
636 |
+
else:
|
637 |
+
xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin))
|
638 |
+
list_xSR.append(xSRz)
|
639 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
640 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
641 |
+
return xSRvar
|
642 |
+
|
643 |
+
def get_uncer_DO(
|
644 |
+
NetG,
|
645 |
+
xin,
|
646 |
+
dop=0.2,
|
647 |
+
num_runs=50,
|
648 |
+
task=None,
|
649 |
+
xMask=None,
|
650 |
+
):
|
651 |
+
list_xSR = []
|
652 |
+
with torch.no_grad():
|
653 |
+
for z in range(num_runs):
|
654 |
+
if task == 'inpainting':
|
655 |
+
_, xSRz = NetG(xin, xMask, dop=dop)
|
656 |
+
else:
|
657 |
+
xSRz = NetG(xin, dop=dop)
|
658 |
+
list_xSR.append(xSRz)
|
659 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
660 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
661 |
+
return xSRvar
|
662 |
+
|
663 |
+
################### Different eval functions
|
664 |
+
|
665 |
+
def eval_BayesCap(
|
666 |
+
NetC,
|
667 |
+
NetG,
|
668 |
+
eval_loader,
|
669 |
+
device='cuda',
|
670 |
+
dtype=torch.cuda.FloatTensor,
|
671 |
+
task=None,
|
672 |
+
xMask=None,
|
673 |
+
):
|
674 |
+
NetC.to(device)
|
675 |
+
NetC.eval()
|
676 |
+
NetG.to(device)
|
677 |
+
NetG.eval()
|
678 |
+
|
679 |
+
mean_ssim = 0
|
680 |
+
mean_psnr = 0
|
681 |
+
mean_mse = 0
|
682 |
+
mean_mae = 0
|
683 |
+
num_imgs = 0
|
684 |
+
list_error = []
|
685 |
+
list_var = []
|
686 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
687 |
+
for (idx, batch) in enumerate(tepoch):
|
688 |
+
tepoch.set_description('Validating ...')
|
689 |
+
##
|
690 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
691 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
692 |
+
if task == 'inpainting':
|
693 |
+
if xMask==None:
|
694 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
695 |
+
xMask = xMask.to(device).type(dtype)
|
696 |
+
else:
|
697 |
+
xMask = xMask.to(device).type(dtype)
|
698 |
+
# pass them through the network
|
699 |
+
with torch.no_grad():
|
700 |
+
if task == 'inpainting':
|
701 |
+
_, xSR = NetG(xLR, xMask)
|
702 |
+
elif task == 'depth':
|
703 |
+
xSR = NetG(xLR)[("disp", 0)]
|
704 |
+
else:
|
705 |
+
xSR = NetG(xLR)
|
706 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
707 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
708 |
+
b_map = xSRC_beta.to('cpu').data
|
709 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
710 |
+
n_batch = xSRC_mu.shape[0]
|
711 |
+
if task == 'depth':
|
712 |
+
xHR = xSR
|
713 |
+
for j in range(n_batch):
|
714 |
+
num_imgs += 1
|
715 |
+
mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
|
716 |
+
mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
|
717 |
+
mean_mse += img_mse(xSRC_mu[j], xHR[j])
|
718 |
+
mean_mae += img_mae(xSRC_mu[j], xHR[j])
|
719 |
+
|
720 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
721 |
+
|
722 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
723 |
+
var_map = xSRvar[j].to('cpu').data.reshape(-1)
|
724 |
+
list_error.extend(list(error_map.numpy()))
|
725 |
+
list_var.extend(list(var_map.numpy()))
|
726 |
+
##
|
727 |
+
mean_ssim /= num_imgs
|
728 |
+
mean_psnr /= num_imgs
|
729 |
+
mean_mse /= num_imgs
|
730 |
+
mean_mae /= num_imgs
|
731 |
+
print(
|
732 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
733 |
+
(
|
734 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
735 |
+
)
|
736 |
+
)
|
737 |
+
# print(len(list_error), len(list_var))
|
738 |
+
# print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
|
739 |
+
# print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
|
740 |
+
return mean_ssim
|
741 |
+
|
742 |
+
def eval_TTDA_p(
|
743 |
+
NetG,
|
744 |
+
eval_loader,
|
745 |
+
device='cuda',
|
746 |
+
dtype=torch.cuda.FloatTensor,
|
747 |
+
p_mag=0.05,
|
748 |
+
num_runs=50,
|
749 |
+
task = None,
|
750 |
+
xMask = None,
|
751 |
+
):
|
752 |
+
NetG.to(device)
|
753 |
+
NetG.eval()
|
754 |
+
|
755 |
+
mean_ssim = 0
|
756 |
+
mean_psnr = 0
|
757 |
+
mean_mse = 0
|
758 |
+
mean_mae = 0
|
759 |
+
num_imgs = 0
|
760 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
761 |
+
for (idx, batch) in enumerate(tepoch):
|
762 |
+
tepoch.set_description('Validating ...')
|
763 |
+
##
|
764 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
765 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
766 |
+
# pass them through the network
|
767 |
+
list_xSR = []
|
768 |
+
with torch.no_grad():
|
769 |
+
if task=='inpainting':
|
770 |
+
_, xSR = NetG(xLR, xMask)
|
771 |
+
else:
|
772 |
+
xSR = NetG(xLR)
|
773 |
+
for z in range(num_runs):
|
774 |
+
xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
|
775 |
+
list_xSR.append(xSRz)
|
776 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
777 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
778 |
+
n_batch = xSR.shape[0]
|
779 |
+
for j in range(n_batch):
|
780 |
+
num_imgs += 1
|
781 |
+
mean_ssim += img_ssim(xSR[j], xHR[j])
|
782 |
+
mean_psnr += img_psnr(xSR[j], xHR[j])
|
783 |
+
mean_mse += img_mse(xSR[j], xHR[j])
|
784 |
+
mean_mae += img_mae(xSR[j], xHR[j])
|
785 |
+
|
786 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
787 |
+
|
788 |
+
mean_ssim /= num_imgs
|
789 |
+
mean_psnr /= num_imgs
|
790 |
+
mean_mse /= num_imgs
|
791 |
+
mean_mae /= num_imgs
|
792 |
+
print(
|
793 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
794 |
+
(
|
795 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
796 |
+
)
|
797 |
+
)
|
798 |
+
|
799 |
+
return mean_ssim
|
800 |
+
|
801 |
+
def eval_DO(
|
802 |
+
NetG,
|
803 |
+
eval_loader,
|
804 |
+
device='cuda',
|
805 |
+
dtype=torch.cuda.FloatTensor,
|
806 |
+
dop=0.2,
|
807 |
+
num_runs=50,
|
808 |
+
task=None,
|
809 |
+
xMask=None,
|
810 |
+
):
|
811 |
+
NetG.to(device)
|
812 |
+
NetG.eval()
|
813 |
+
|
814 |
+
mean_ssim = 0
|
815 |
+
mean_psnr = 0
|
816 |
+
mean_mse = 0
|
817 |
+
mean_mae = 0
|
818 |
+
num_imgs = 0
|
819 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
820 |
+
for (idx, batch) in enumerate(tepoch):
|
821 |
+
tepoch.set_description('Validating ...')
|
822 |
+
##
|
823 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
824 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
825 |
+
# pass them through the network
|
826 |
+
list_xSR = []
|
827 |
+
with torch.no_grad():
|
828 |
+
if task == 'inpainting':
|
829 |
+
_, xSR = NetG(xLR, xMask)
|
830 |
+
else:
|
831 |
+
xSR = NetG(xLR)
|
832 |
+
for z in range(num_runs):
|
833 |
+
xSRz = NetG(xLR, dop=dop)
|
834 |
+
list_xSR.append(xSRz)
|
835 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
836 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
837 |
+
n_batch = xSR.shape[0]
|
838 |
+
for j in range(n_batch):
|
839 |
+
num_imgs += 1
|
840 |
+
mean_ssim += img_ssim(xSR[j], xHR[j])
|
841 |
+
mean_psnr += img_psnr(xSR[j], xHR[j])
|
842 |
+
mean_mse += img_mse(xSR[j], xHR[j])
|
843 |
+
mean_mae += img_mae(xSR[j], xHR[j])
|
844 |
+
|
845 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
846 |
+
##
|
847 |
+
mean_ssim /= num_imgs
|
848 |
+
mean_psnr /= num_imgs
|
849 |
+
mean_mse /= num_imgs
|
850 |
+
mean_mae /= num_imgs
|
851 |
+
print(
|
852 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
853 |
+
(
|
854 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
855 |
+
)
|
856 |
+
)
|
857 |
+
|
858 |
+
return mean_ssim
|
859 |
+
|
860 |
+
|
861 |
+
############### compare all function
|
862 |
+
def compare_all(
|
863 |
+
NetC,
|
864 |
+
NetG,
|
865 |
+
eval_loader,
|
866 |
+
p_mag = 0.05,
|
867 |
+
dop = 0.2,
|
868 |
+
num_runs = 100,
|
869 |
+
device='cuda',
|
870 |
+
dtype=torch.cuda.FloatTensor,
|
871 |
+
task=None,
|
872 |
+
):
|
873 |
+
NetC.to(device)
|
874 |
+
NetC.eval()
|
875 |
+
NetG.to(device)
|
876 |
+
NetG.eval()
|
877 |
+
|
878 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
879 |
+
for (idx, batch) in enumerate(tepoch):
|
880 |
+
tepoch.set_description('Comparing ...')
|
881 |
+
##
|
882 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
883 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
884 |
+
if task == 'inpainting':
|
885 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
886 |
+
xMask = xMask.to(device).type(dtype)
|
887 |
+
# pass them through the network
|
888 |
+
with torch.no_grad():
|
889 |
+
if task == 'inpainting':
|
890 |
+
_, xSR = NetG(xLR, xMask)
|
891 |
+
else:
|
892 |
+
xSR = NetG(xLR)
|
893 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
894 |
+
|
895 |
+
if task == 'inpainting':
|
896 |
+
xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
|
897 |
+
xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
|
898 |
+
xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
|
899 |
+
else:
|
900 |
+
xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
|
901 |
+
xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
|
902 |
+
xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)
|
903 |
+
|
904 |
+
print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
|
905 |
+
|
906 |
+
n_batch = xSR.shape[0]
|
907 |
+
for j in range(n_batch):
|
908 |
+
if task=='s':
|
909 |
+
show_SR_w_err(xLR[j], xHR[j], xSR[j])
|
910 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
|
911 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
|
912 |
+
if task=='d':
|
913 |
+
show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
|
914 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
|
915 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
|
916 |
+
if task=='inpainting':
|
917 |
+
show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
|
918 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
|
919 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
|
920 |
+
if task=='m':
|
921 |
+
show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
|
922 |
+
show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
|
923 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
|
924 |
+
|
925 |
+
|
926 |
+
################# Degrading Identity
|
927 |
+
def degrage_BayesCap_p(
|
928 |
+
NetC,
|
929 |
+
NetG,
|
930 |
+
eval_loader,
|
931 |
+
device='cuda',
|
932 |
+
dtype=torch.cuda.FloatTensor,
|
933 |
+
num_runs=50,
|
934 |
+
):
|
935 |
+
NetC.to(device)
|
936 |
+
NetC.eval()
|
937 |
+
NetG.to(device)
|
938 |
+
NetG.eval()
|
939 |
+
|
940 |
+
p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
|
941 |
+
list_s = []
|
942 |
+
list_p = []
|
943 |
+
list_u1 = []
|
944 |
+
list_u2 = []
|
945 |
+
list_c = []
|
946 |
+
for p_mag in p_mag_list:
|
947 |
+
mean_ssim = 0
|
948 |
+
mean_psnr = 0
|
949 |
+
mean_mse = 0
|
950 |
+
mean_mae = 0
|
951 |
+
num_imgs = 0
|
952 |
+
list_error = []
|
953 |
+
list_error2 = []
|
954 |
+
list_var = []
|
955 |
+
|
956 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
957 |
+
for (idx, batch) in enumerate(tepoch):
|
958 |
+
tepoch.set_description('Validating ...')
|
959 |
+
##
|
960 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
961 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
962 |
+
# pass them through the network
|
963 |
+
with torch.no_grad():
|
964 |
+
xSR = NetG(xLR)
|
965 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
|
966 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
967 |
+
b_map = xSRC_beta.to('cpu').data
|
968 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
969 |
+
n_batch = xSRC_mu.shape[0]
|
970 |
+
for j in range(n_batch):
|
971 |
+
num_imgs += 1
|
972 |
+
mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
|
973 |
+
mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
|
974 |
+
mean_mse += img_mse(xSRC_mu[j], xSR[j])
|
975 |
+
mean_mae += img_mae(xSRC_mu[j], xSR[j])
|
976 |
+
|
977 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
978 |
+
error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
979 |
+
var_map = xSRvar[j].to('cpu').data.reshape(-1)
|
980 |
+
list_error.extend(list(error_map.numpy()))
|
981 |
+
list_error2.extend(list(error_map2.numpy()))
|
982 |
+
list_var.extend(list(var_map.numpy()))
|
983 |
+
##
|
984 |
+
mean_ssim /= num_imgs
|
985 |
+
mean_psnr /= num_imgs
|
986 |
+
mean_mse /= num_imgs
|
987 |
+
mean_mae /= num_imgs
|
988 |
+
print(
|
989 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
990 |
+
(
|
991 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
992 |
+
)
|
993 |
+
)
|
994 |
+
uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
|
995 |
+
uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
|
996 |
+
print('UCE1: ', uce1)
|
997 |
+
print('UCE2: ', uce2)
|
998 |
+
list_s.append(mean_ssim.item())
|
999 |
+
list_p.append(mean_psnr.item())
|
1000 |
+
list_u1.append(uce1)
|
1001 |
+
list_u2.append(uce2)
|
1002 |
+
|
1003 |
+
plt.plot(list_s)
|
1004 |
+
plt.show()
|
1005 |
+
plt.plot(list_p)
|
1006 |
+
plt.show()
|
1007 |
+
|
1008 |
+
plt.plot(list_u1, label='wrt SR output')
|
1009 |
+
plt.plot(list_u2, label='wrt BayesCap output')
|
1010 |
+
plt.legend()
|
1011 |
+
plt.show()
|
1012 |
+
|
1013 |
+
sns.set_style('darkgrid')
|
1014 |
+
fig,ax = plt.subplots()
|
1015 |
+
# make a plot
|
1016 |
+
ax.plot(p_mag_list, list_s, color="red", marker="o")
|
1017 |
+
# set x-axis label
|
1018 |
+
ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
|
1019 |
+
# set y-axis label
|
1020 |
+
ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
|
1021 |
+
|
1022 |
+
# twin object for two different y-axis on the sample plot
|
1023 |
+
ax2=ax.twinx()
|
1024 |
+
# make a plot with different y-axis using second axis object
|
1025 |
+
ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt error btwn SRGAN output and GT')
|
1026 |
+
ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt error btwn BayesCap output and GT')
|
1027 |
+
ax2.set_ylabel("UCE", color="green", fontsize=10)
|
1028 |
+
plt.legend(fontsize=10)
|
1029 |
+
plt.tight_layout()
|
1030 |
+
plt.show()
|
1031 |
+
|
1032 |
+
################# DeepFill_v2
|
1033 |
+
|
1034 |
+
# ----------------------------------------
|
1035 |
+
# PATH processing
|
1036 |
+
# ----------------------------------------
|
1037 |
+
def text_readlines(filename):
|
1038 |
+
# Try to read a txt file and return a list. Return [] if there was a mistake.
|
1039 |
+
try:
|
1040 |
+
file = open(filename, 'r')
|
1041 |
+
except IOError:
|
1042 |
+
error = []
|
1043 |
+
return error
|
1044 |
+
content = file.readlines()
|
1045 |
+
# This for loop deletes the EOF (like \n)
|
1046 |
+
for i in range(len(content)):
|
1047 |
+
content[i] = content[i][:len(content[i])-1]
|
1048 |
+
file.close()
|
1049 |
+
return content
|
1050 |
+
|
1051 |
+
def savetxt(name, loss_log):
|
1052 |
+
np_loss_log = np.array(loss_log)
|
1053 |
+
np.savetxt(name, np_loss_log)
|
1054 |
+
|
1055 |
+
def get_files(path):
|
1056 |
+
# read a folder, return the complete path
|
1057 |
+
ret = []
|
1058 |
+
for root, dirs, files in os.walk(path):
|
1059 |
+
for filespath in files:
|
1060 |
+
ret.append(os.path.join(root, filespath))
|
1061 |
+
return ret
|
1062 |
+
|
1063 |
+
def get_names(path):
|
1064 |
+
# read a folder, return the image name
|
1065 |
+
ret = []
|
1066 |
+
for root, dirs, files in os.walk(path):
|
1067 |
+
for filespath in files:
|
1068 |
+
ret.append(filespath)
|
1069 |
+
return ret
|
1070 |
+
|
1071 |
+
def text_save(content, filename, mode = 'a'):
|
1072 |
+
# save a list to a txt
|
1073 |
+
# Try to save a list variable in txt file.
|
1074 |
+
file = open(filename, mode)
|
1075 |
+
for i in range(len(content)):
|
1076 |
+
file.write(str(content[i]) + '\n')
|
1077 |
+
file.close()
|
1078 |
+
|
1079 |
+
def check_path(path):
|
1080 |
+
if not os.path.exists(path):
|
1081 |
+
os.makedirs(path)
|
1082 |
+
|
1083 |
+
# ----------------------------------------
|
1084 |
+
# Validation and Sample at training
|
1085 |
+
# ----------------------------------------
|
1086 |
+
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
|
1087 |
+
# Save image one-by-one
|
1088 |
+
for i in range(len(img_list)):
|
1089 |
+
img = img_list[i]
|
1090 |
+
# Recover normalization: * 255 because last layer is sigmoid activated
|
1091 |
+
img = img * 255
|
1092 |
+
# Process img_copy and do not destroy the data of img
|
1093 |
+
img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
|
1094 |
+
img_copy = np.clip(img_copy, 0, pixel_max_cnt)
|
1095 |
+
img_copy = img_copy.astype(np.uint8)
|
1096 |
+
img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
|
1097 |
+
# Save to certain path
|
1098 |
+
save_img_name = sample_name + '_' + name_list[i] + '.jpg'
|
1099 |
+
save_img_path = os.path.join(sample_folder, save_img_name)
|
1100 |
+
cv2.imwrite(save_img_path, img_copy)
|
1101 |
+
|
1102 |
+
def psnr(pred, target, pixel_max_cnt = 255):
|
1103 |
+
mse = torch.mul(target - pred, target - pred)
|
1104 |
+
rmse_avg = (torch.mean(mse).item()) ** 0.5
|
1105 |
+
p = 20 * np.log10(pixel_max_cnt / rmse_avg)
|
1106 |
+
return p
|
1107 |
+
|
1108 |
+
def grey_psnr(pred, target, pixel_max_cnt = 255):
|
1109 |
+
pred = torch.sum(pred, dim = 0)
|
1110 |
+
target = torch.sum(target, dim = 0)
|
1111 |
+
mse = torch.mul(target - pred, target - pred)
|
1112 |
+
rmse_avg = (torch.mean(mse).item()) ** 0.5
|
1113 |
+
p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
|
1114 |
+
return p
|
1115 |
+
|
1116 |
+
def ssim(pred, target):
|
1117 |
+
pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()
|
1118 |
+
target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()
|
1119 |
+
target = target[0]
|
1120 |
+
pred = pred[0]
|
1121 |
+
ssim = skimage.measure.compare_ssim(target, pred, multichannel = True)
|
1122 |
+
return ssim
|
1123 |
+
|
1124 |
+
## for contextual attention
|
1125 |
+
|
1126 |
+
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
|
1127 |
+
"""
|
1128 |
+
Extract patches from images and put them in the C output dimension.
|
1129 |
+
:param padding:
|
1130 |
+
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
|
1131 |
+
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
|
1132 |
+
each dimension of images
|
1133 |
+
:param strides: [stride_rows, stride_cols]
|
1134 |
+
:param rates: [dilation_rows, dilation_cols]
|
1135 |
+
:return: A Tensor
|
1136 |
+
"""
|
1137 |
+
assert len(images.size()) == 4
|
1138 |
+
assert padding in ['same', 'valid']
|
1139 |
+
batch_size, channel, height, width = images.size()
|
1140 |
+
|
1141 |
+
if padding == 'same':
|
1142 |
+
images = same_padding(images, ksizes, strides, rates)
|
1143 |
+
elif padding == 'valid':
|
1144 |
+
pass
|
1145 |
+
else:
|
1146 |
+
raise NotImplementedError('Unsupported padding type: {}.\
|
1147 |
+
Only "same" or "valid" are supported.'.format(padding))
|
1148 |
+
|
1149 |
+
unfold = torch.nn.Unfold(kernel_size=ksizes,
|
1150 |
+
dilation=rates,
|
1151 |
+
padding=0,
|
1152 |
+
stride=strides)
|
1153 |
+
patches = unfold(images)
|
1154 |
+
return patches # [N, C*k*k, L], L is the total number of such blocks
|
1155 |
+
|
1156 |
+
def same_padding(images, ksizes, strides, rates):
|
1157 |
+
assert len(images.size()) == 4
|
1158 |
+
batch_size, channel, rows, cols = images.size()
|
1159 |
+
out_rows = (rows + strides[0] - 1) // strides[0]
|
1160 |
+
out_cols = (cols + strides[1] - 1) // strides[1]
|
1161 |
+
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
|
1162 |
+
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
|
1163 |
+
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
|
1164 |
+
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
|
1165 |
+
# Pad the input
|
1166 |
+
padding_top = int(padding_rows / 2.)
|
1167 |
+
padding_left = int(padding_cols / 2.)
|
1168 |
+
padding_bottom = padding_rows - padding_top
|
1169 |
+
padding_right = padding_cols - padding_left
|
1170 |
+
paddings = (padding_left, padding_right, padding_top, padding_bottom)
|
1171 |
+
images = torch.nn.ZeroPad2d(paddings)(images)
|
1172 |
+
return images
|
1173 |
+
|
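# A minimal sketch of the patch extraction above (used by contextual attention),
# assuming this module is importable as `utils`: 3x3 patches at stride 1 with
# 'same' padding from a 1x3x8x8 tensor.
import torch
from utils import extract_image_patches

demo_images = torch.randn(1, 3, 8, 8)
demo_patches = extract_image_patches(demo_images, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding='same')
print(demo_patches.shape)  # torch.Size([1, 27, 64]) -> [N, C*k*k, L]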
1174 |
+
def reduce_mean(x, axis=None, keepdim=False):
|
1175 |
+
if not axis:
|
1176 |
+
axis = range(len(x.shape))
|
1177 |
+
for i in sorted(axis, reverse=True):
|
1178 |
+
x = torch.mean(x, dim=i, keepdim=keepdim)
|
1179 |
+
return x
|
1180 |
+
|
1181 |
+
|
1182 |
+
def reduce_std(x, axis=None, keepdim=False):
|
1183 |
+
if not axis:
|
1184 |
+
axis = range(len(x.shape))
|
1185 |
+
for i in sorted(axis, reverse=True):
|
1186 |
+
x = torch.std(x, dim=i, keepdim=keepdim)
|
1187 |
+
return x
|
1188 |
+
|
1189 |
+
|
1190 |
+
def reduce_sum(x, axis=None, keepdim=False):
|
1191 |
+
if not axis:
|
1192 |
+
axis = range(len(x.shape))
|
1193 |
+
for i in sorted(axis, reverse=True):
|
1194 |
+
x = torch.sum(x, dim=i, keepdim=keepdim)
|
1195 |
+
return x
|
1196 |
+
|
1197 |
+
def random_mask(num_batch=1, mask_shape=(256,256)):
|
1198 |
+
list_mask = []
|
1199 |
+
for _ in range(num_batch):
|
1200 |
+
# rectangle mask
|
1201 |
+
image_height = mask_shape[0]
|
1202 |
+
image_width = mask_shape[1]
|
1203 |
+
max_delta_height = image_height//8
|
1204 |
+
max_delta_width = image_width//8
|
1205 |
+
height = image_height//4
|
1206 |
+
width = image_width//4
|
1207 |
+
max_t = image_height - height
|
1208 |
+
max_l = image_width - width
|
1209 |
+
t = random.randint(0, max_t)
|
1210 |
+
l = random.randint(0, max_l)
|
1211 |
+
# bbox = (t, l, height, width)
|
1212 |
+
h = random.randint(0, max_delta_height//2)
|
1213 |
+
w = random.randint(0, max_delta_width//2)
|
1214 |
+
mask = torch.zeros((1, 1, image_height, image_width))
|
1215 |
+
mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1
|
1216 |
+
rect_mask = mask
|
1217 |
+
|
1218 |
+
# brush mask
|
1219 |
+
min_num_vertex = 4
|
1220 |
+
max_num_vertex = 12
|
1221 |
+
mean_angle = 2 * math.pi / 5
|
1222 |
+
angle_range = 2 * math.pi / 15
|
1223 |
+
min_width = 12
|
1224 |
+
max_width = 40
|
1225 |
+
H, W = image_height, image_width
|
1226 |
+
average_radius = math.sqrt(H*H+W*W) / 8
|
1227 |
+
mask = Image.new('L', (W, H), 0)
|
1228 |
+
|
1229 |
+
for _ in range(np.random.randint(1, 4)):
|
1230 |
+
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
|
1231 |
+
angle_min = mean_angle - np.random.uniform(0, angle_range)
|
1232 |
+
angle_max = mean_angle + np.random.uniform(0, angle_range)
|
1233 |
+
angles = []
|
1234 |
+
vertex = []
|
1235 |
+
for i in range(num_vertex):
|
1236 |
+
if i % 2 == 0:
|
1237 |
+
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
|
1238 |
+
else:
|
1239 |
+
angles.append(np.random.uniform(angle_min, angle_max))
|
1240 |
+
|
1241 |
+
h, w = mask.size
|
1242 |
+
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
|
1243 |
+
for i in range(num_vertex):
|
1244 |
+
r = np.clip(
|
1245 |
+
np.random.normal(loc=average_radius, scale=average_radius//2),
|
1246 |
+
0, 2*average_radius)
|
1247 |
+
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
|
1248 |
+
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
|
1249 |
+
vertex.append((int(new_x), int(new_y)))
|
1250 |
+
|
1251 |
+
draw = ImageDraw.Draw(mask)
|
1252 |
+
width = int(np.random.uniform(min_width, max_width))
|
1253 |
+
draw.line(vertex, fill=255, width=width)
|
1254 |
+
for v in vertex:
|
1255 |
+
draw.ellipse((v[0] - width//2,
|
1256 |
+
v[1] - width//2,
|
1257 |
+
v[0] + width//2,
|
1258 |
+
v[1] + width//2),
|
1259 |
+
fill=255)
|
1260 |
+
|
1261 |
+
if np.random.normal() > 0:
|
1262 |
+
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
1263 |
+
if np.random.normal() > 0:
|
1264 |
+
mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
|
1265 |
+
|
1266 |
+
mask = transforms.ToTensor()(mask)
|
1267 |
+
mask = mask.reshape((1, 1, H, W))
|
1268 |
+
brush_mask = mask
|
1269 |
+
|
1270 |
+
mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
|
1271 |
+
list_mask.append(mask)
|
1272 |
+
mask = torch.cat(list_mask, dim=0)
|
1273 |
+
return mask
|
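# A minimal sketch of the mask generator above for the inpainting task, assuming
# this module is importable as `utils`; the mask is 1 inside the holes, so the
# network input is the image with the masked region zeroed out (as in compare_all).
import torch
from utils import random_mask

demo_hr = torch.rand(2, 3, 256, 256)                          # a batch of images
demo_mask = random_mask(num_batch=2, mask_shape=(256, 256))   # 2x1x256x256 binary mask
demo_masked = demo_hr * (1 - demo_mask)                       # zero out the holes
print(demo_mask.shape, demo_masked.shape)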
utils.py
ADDED
@@ -0,0 +1,1304 @@
1 |
+
import random
|
2 |
+
from typing import Any, Optional
|
3 |
+
import numpy as np
|
4 |
+
import os
|
5 |
+
import cv2
|
6 |
+
from glob import glob
|
7 |
+
from PIL import Image, ImageDraw
|
8 |
+
from tqdm import tqdm
|
9 |
+
import kornia
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import seaborn as sns
|
12 |
+
import albumentations as albu
|
13 |
+
import functools
|
14 |
+
import math
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from torch import Tensor
|
19 |
+
import torchvision as tv
|
20 |
+
import torchvision.models as models
|
21 |
+
from torchvision import transforms
|
22 |
+
from torchvision.transforms import functional as F
|
23 |
+
from losses import TempCombLoss
|
24 |
+
|
25 |
+
|
26 |
+
######## for loading checkpoint from googledrive
|
27 |
+
google_drive_paths = {
|
28 |
+
"BayesCap_SRGAN.pth": "https://drive.google.com/uc?id=1d_5j1f8-vN79htZTfRUqP1ddHZIYsNvL",
|
29 |
+
"BayesCap_ckpt.pth": "https://drive.google.com/uc?id=1Vg1r6gKgQ1J3M51n6BeKXYS8auT9NhA9",
|
30 |
+
}
|
31 |
+
|
32 |
+
def ensure_checkpoint_exists(model_weights_filename):
|
33 |
+
if not os.path.isfile(model_weights_filename) and (
|
34 |
+
model_weights_filename in google_drive_paths
|
35 |
+
):
|
36 |
+
gdrive_url = google_drive_paths[model_weights_filename]
|
37 |
+
try:
|
38 |
+
from gdown import download as drive_download
|
39 |
+
|
40 |
+
drive_download(gdrive_url, model_weights_filename, quiet=False)
|
41 |
+
except ModuleNotFoundError:
|
42 |
+
print(
|
43 |
+
"gdown module not found.",
|
44 |
+
"pip3 install gdown or, manually download the checkpoint file:",
|
45 |
+
gdrive_url
|
46 |
+
)
|
47 |
+
|
48 |
+
if not os.path.isfile(model_weights_filename) and (
|
49 |
+
model_weights_filename not in google_drive_paths
|
50 |
+
):
|
51 |
+
print(
|
52 |
+
model_weights_filename,
|
53 |
+
" not found, you may need to manually download the model weights."
|
54 |
+
)
|
55 |
+
|
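# A minimal usage sketch for the checkpoint helper above: fetch one of the
# pretrained weights listed in google_drive_paths before loading the model
# (downloads via gdown when it is installed, otherwise prints the URL).
from utils import ensure_checkpoint_exists

ensure_checkpoint_exists("BayesCap_SRGAN.pth")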
56 |
+
########### DeblurGAN function
|
57 |
+
def get_norm_layer(norm_type='instance'):
|
58 |
+
if norm_type == 'batch':
|
59 |
+
norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
|
60 |
+
elif norm_type == 'instance':
|
61 |
+
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=True)
|
62 |
+
else:
|
63 |
+
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
64 |
+
return norm_layer
|
65 |
+
|
66 |
+
def _array_to_batch(x):
|
67 |
+
x = np.transpose(x, (2, 0, 1))
|
68 |
+
x = np.expand_dims(x, 0)
|
69 |
+
return torch.from_numpy(x)
|
70 |
+
|
71 |
+
def get_normalize():
|
72 |
+
normalize = albu.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
73 |
+
normalize = albu.Compose([normalize], additional_targets={'target': 'image'})
|
74 |
+
|
75 |
+
def process(a, b):
|
76 |
+
r = normalize(image=a, target=b)
|
77 |
+
return r['image'], r['target']
|
78 |
+
|
79 |
+
return process
|
80 |
+
|
81 |
+
def preprocess(x: np.ndarray, mask: Optional[np.ndarray]):
|
82 |
+
x, _ = get_normalize()(x, x)
|
83 |
+
if mask is None:
|
84 |
+
mask = np.ones_like(x, dtype=np.float32)
|
85 |
+
else:
|
86 |
+
mask = np.round(mask.astype('float32') / 255)
|
87 |
+
|
88 |
+
h, w, _ = x.shape
|
89 |
+
block_size = 32
|
90 |
+
min_height = (h // block_size + 1) * block_size
|
91 |
+
min_width = (w // block_size + 1) * block_size
|
92 |
+
|
93 |
+
pad_params = {'mode': 'constant',
|
94 |
+
'constant_values': 0,
|
95 |
+
'pad_width': ((0, min_height - h), (0, min_width - w), (0, 0))
|
96 |
+
}
|
97 |
+
x = np.pad(x, **pad_params)
|
98 |
+
mask = np.pad(mask, **pad_params)
|
99 |
+
|
100 |
+
return map(_array_to_batch, (x, mask)), h, w
|
101 |
+
|
102 |
+
def postprocess(x: torch.Tensor) -> np.ndarray:
|
103 |
+
x, = x
|
104 |
+
x = x.detach().cpu().float().numpy()
|
105 |
+
x = (np.transpose(x, (1, 2, 0)) + 1) / 2.0 * 255.0
|
106 |
+
return x.astype('uint8')
|
107 |
+
|
108 |
+
def sorted_glob(pattern):
|
109 |
+
return sorted(glob(pattern))
|
110 |
+
###########
|
111 |
+
|
112 |
+
def normalize(image: np.ndarray) -> np.ndarray:
|
113 |
+
"""Normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
114 |
+
Args:
|
115 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
116 |
+
Returns:
|
117 |
+
Normalized image data. Data range [0, 1].
|
118 |
+
"""
|
119 |
+
return image.astype(np.float64) / 255.0
|
120 |
+
|
121 |
+
|
122 |
+
def unnormalize(image: np.ndarray) -> np.ndarray:
|
123 |
+
"""Un-normalize the ``OpenCV.imread`` or ``skimage.io.imread`` data.
|
124 |
+
Args:
|
125 |
+
image (np.ndarray): The image data read by ``OpenCV.imread`` or ``skimage.io.imread``.
|
126 |
+
Returns:
|
127 |
+
Denormalized image data. Data range [0, 255].
|
128 |
+
"""
|
129 |
+
return image.astype(np.float64) * 255.0
|
130 |
+
|
131 |
+
|
132 |
+
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
|
133 |
+
"""Convert ``PIL.Image`` to Tensor.
|
134 |
+
Args:
|
135 |
+
image (np.ndarray): The image data read by ``PIL.Image``
|
136 |
+
range_norm (bool): Scale [0, 1] data to between [-1, 1]
|
137 |
+
half (bool): Whether to cast the tensor from torch.float32 to torch.half.
|
138 |
+
Returns:
|
139 |
+
Normalized image data
|
140 |
+
Examples:
|
141 |
+
>>> image = Image.open("image.bmp")
|
142 |
+
>>> tensor_image = image2tensor(image, range_norm=False, half=False)
|
143 |
+
"""
|
144 |
+
tensor = F.to_tensor(image)
|
145 |
+
|
146 |
+
if range_norm:
|
147 |
+
tensor = tensor.mul_(2.0).sub_(1.0)
|
148 |
+
if half:
|
149 |
+
tensor = tensor.half()
|
150 |
+
|
151 |
+
return tensor
|
152 |
+
|
153 |
+
|
154 |
+
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
|
155 |
+
"""Converts ``torch.Tensor`` to ``PIL.Image``.
|
156 |
+
Args:
|
157 |
+
tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image``
|
158 |
+
range_norm (bool): Scale [-1, 1] data to between [0, 1]
|
159 |
+
half (bool): Whether to cast the tensor from torch.float32 to torch.half.
|
160 |
+
Returns:
|
161 |
+
Convert image data to support PIL library
|
162 |
+
Examples:
|
163 |
+
>>> tensor = torch.randn([1, 3, 128, 128])
|
164 |
+
>>> image = tensor2image(tensor, range_norm=False, half=False)
|
165 |
+
"""
|
166 |
+
if range_norm:
|
167 |
+
tensor = tensor.add_(1.0).div_(2.0)
|
168 |
+
if half:
|
169 |
+
tensor = tensor.half()
|
170 |
+
|
171 |
+
image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8")
|
172 |
+
|
173 |
+
return image
|
174 |
+
|
175 |
+
|
176 |
+
def convert_rgb_to_y(image: Any) -> Any:
|
177 |
+
"""Convert RGB image or tensor image data to YCbCr(Y) format.
|
178 |
+
Args:
|
179 |
+
image: RGB image data read by ``PIL.Image``.
|
180 |
+
Returns:
|
181 |
+
Y image array data.
|
182 |
+
"""
|
183 |
+
if type(image) == np.ndarray:
|
184 |
+
return 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
|
185 |
+
elif type(image) == torch.Tensor:
|
186 |
+
if len(image.shape) == 4:
|
187 |
+
image = image.squeeze_(0)
|
188 |
+
return 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
|
189 |
+
else:
|
190 |
+
raise Exception("Unknown Type", type(image))
|
191 |
+
|
192 |
+
|
193 |
+
def convert_rgb_to_ycbcr(image: Any) -> Any:
|
194 |
+
"""Convert RGB image or tensor image data to YCbCr format.
|
195 |
+
Args:
|
196 |
+
image: RGB image data read by ``PIL.Image``.
|
197 |
+
Returns:
|
198 |
+
YCbCr image array data.
|
199 |
+
"""
|
200 |
+
if type(image) == np.ndarray:
|
201 |
+
y = 16. + (64.738 * image[:, :, 0] + 129.057 * image[:, :, 1] + 25.064 * image[:, :, 2]) / 256.
|
202 |
+
cb = 128. + (-37.945 * image[:, :, 0] - 74.494 * image[:, :, 1] + 112.439 * image[:, :, 2]) / 256.
|
203 |
+
cr = 128. + (112.439 * image[:, :, 0] - 94.154 * image[:, :, 1] - 18.285 * image[:, :, 2]) / 256.
|
204 |
+
return np.array([y, cb, cr]).transpose([1, 2, 0])
|
205 |
+
elif type(image) == torch.Tensor:
|
206 |
+
if len(image.shape) == 4:
|
207 |
+
image = image.squeeze(0)
|
208 |
+
y = 16. + (64.738 * image[0, :, :] + 129.057 * image[1, :, :] + 25.064 * image[2, :, :]) / 256.
|
209 |
+
cb = 128. + (-37.945 * image[0, :, :] - 74.494 * image[1, :, :] + 112.439 * image[2, :, :]) / 256.
|
210 |
+
cr = 128. + (112.439 * image[0, :, :] - 94.154 * image[1, :, :] - 18.285 * image[2, :, :]) / 256.
|
211 |
+
return torch.cat([y, cb, cr], 0).permute(1, 2, 0)
|
212 |
+
else:
|
213 |
+
raise Exception("Unknown Type", type(image))
|
214 |
+
|
215 |
+
|
216 |
+
def convert_ycbcr_to_rgb(image: Any) -> Any:
|
217 |
+
"""Convert YCbCr format image to RGB format.
|
218 |
+
Args:
|
219 |
+
image: YCbCr image data read by ``PIL.Image``.
|
220 |
+
Returns:
|
221 |
+
RGB image array data.
|
222 |
+
"""
|
223 |
+
if type(image) == np.ndarray:
|
224 |
+
r = 298.082 * image[:, :, 0] / 256. + 408.583 * image[:, :, 2] / 256. - 222.921
|
225 |
+
g = 298.082 * image[:, :, 0] / 256. - 100.291 * image[:, :, 1] / 256. - 208.120 * image[:, :, 2] / 256. + 135.576
|
226 |
+
b = 298.082 * image[:, :, 0] / 256. + 516.412 * image[:, :, 1] / 256. - 276.836
|
227 |
+
return np.array([r, g, b]).transpose([1, 2, 0])
|
228 |
+
elif type(image) == torch.Tensor:
|
229 |
+
if len(image.shape) == 4:
|
230 |
+
image = image.squeeze(0)
|
231 |
+
r = 298.082 * image[0, :, :] / 256. + 408.583 * image[2, :, :] / 256. - 222.921
|
232 |
+
g = 298.082 * image[0, :, :] / 256. - 100.291 * image[1, :, :] / 256. - 208.120 * image[2, :, :] / 256. + 135.576
|
233 |
+
b = 298.082 * image[0, :, :] / 256. + 516.412 * image[1, :, :] / 256. - 276.836
|
234 |
+
return torch.cat([r, g, b], 0).permute(1, 2, 0)
|
235 |
+
else:
|
236 |
+
raise Exception("Unknown Type", type(image))
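# Illustrative sketch: the three colour-space helpers above use BT.601-style integer
# coefficients common in SR codebases and expect pixel values in [0, 255] (Y lands
# roughly in [16, 235]). Y-channel PSNR, the usual SR evaluation convention, can then
# be computed as follows (sr_np / hr_np are placeholder HxWx3 arrays in [0, 255]):
# sr_y = convert_rgb_to_y(sr_np.astype(np.float64))
# hr_y = convert_rgb_to_y(hr_np.astype(np.float64))
# psnr_y = 10 * np.log10(255.0 ** 2 / np.mean((sr_y - hr_y) ** 2))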
|
237 |
+
|
238 |
+
|
239 |
+
def center_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
|
240 |
+
"""Cut ``PIL.Image`` in the center area of the image.
|
241 |
+
Args:
|
242 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
243 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
244 |
+
image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
|
245 |
+
upscale_factor (int): magnification factor.
|
246 |
+
Returns:
|
247 |
+
Center-cropped low-resolution and high-resolution images.
|
248 |
+
"""
|
249 |
+
w, h = hr.size
|
250 |
+
|
251 |
+
left = (w - image_size) // 2
|
252 |
+
top = (h - image_size) // 2
|
253 |
+
right = left + image_size
|
254 |
+
bottom = top + image_size
|
255 |
+
|
256 |
+
lr = lr.crop((left // upscale_factor,
|
257 |
+
top // upscale_factor,
|
258 |
+
right // upscale_factor,
|
259 |
+
bottom // upscale_factor))
|
260 |
+
hr = hr.crop((left, top, right, bottom))
|
261 |
+
|
262 |
+
return lr, hr
|
263 |
+
|
264 |
+
|
265 |
+
def random_crop(lr: Any, hr: Any, image_size: int, upscale_factor: int) -> [Any, Any]:
|
266 |
+
"""Will ``PIL.Image`` randomly capture the specified area of the image.
|
267 |
+
Args:
|
268 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
269 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
270 |
+
image_size (int): The size of the captured image area. It should be the size of the high-resolution image.
|
271 |
+
upscale_factor (int): magnification factor.
|
272 |
+
Returns:
|
273 |
+
Randomly cropped low-resolution images and high-resolution images.
|
274 |
+
"""
|
275 |
+
w, h = hr.size
|
276 |
+
left = torch.randint(0, w - image_size + 1, size=(1,)).item()
|
277 |
+
top = torch.randint(0, h - image_size + 1, size=(1,)).item()
|
278 |
+
right = left + image_size
|
279 |
+
bottom = top + image_size
|
280 |
+
|
281 |
+
lr = lr.crop((left // upscale_factor,
|
282 |
+
top // upscale_factor,
|
283 |
+
right // upscale_factor,
|
284 |
+
bottom // upscale_factor))
|
285 |
+
hr = hr.crop((left, top, right, bottom))
|
286 |
+
|
287 |
+
return lr, hr
|
288 |
+
|
289 |
+
|
290 |
+
def random_rotate(lr: Any, hr: Any, angle: int) -> [Any, Any]:
|
291 |
+
"""Will ``PIL.Image`` randomly rotate the image.
|
292 |
+
Args:
|
293 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
294 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
295 |
+
angle (int): Rotation angle; the direction (clockwise or counterclockwise) is chosen at random.
|
296 |
+
Returns:
|
297 |
+
Randomly rotated low-resolution images and high-resolution images.
|
298 |
+
"""
|
299 |
+
angle = random.choice((+angle, -angle))
|
300 |
+
lr = F.rotate(lr, angle)
|
301 |
+
hr = F.rotate(hr, angle)
|
302 |
+
|
303 |
+
return lr, hr
|
304 |
+
|
305 |
+
|
306 |
+
def random_horizontally_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
|
307 |
+
"""Flip the ``PIL.Image`` image horizontally randomly.
|
308 |
+
Args:
|
309 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
310 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
311 |
+
p (optional, float): Flip probability. (Default: 0.5)
|
312 |
+
Returns:
|
313 |
+
Low-resolution image and high-resolution image after random horizontal flip.
|
314 |
+
"""
|
315 |
+
if torch.rand(1).item() > p:
|
316 |
+
lr = F.hflip(lr)
|
317 |
+
hr = F.hflip(hr)
|
318 |
+
|
319 |
+
return lr, hr
|
320 |
+
|
321 |
+
|
322 |
+
def random_vertically_flip(lr: Any, hr: Any, p=0.5) -> [Any, Any]:
|
323 |
+
"""Turn the ``PIL.Image`` image upside down randomly.
|
324 |
+
Args:
|
325 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
326 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
327 |
+
p (optional, float): Flip probability. (Default: 0.5)
|
328 |
+
Returns:
|
329 |
+
Low-resolution and high-resolution images after random vertical flip.
|
330 |
+
"""
|
331 |
+
if torch.rand(1).item() > p:
|
332 |
+
lr = F.vflip(lr)
|
333 |
+
hr = F.vflip(hr)
|
334 |
+
|
335 |
+
return lr, hr
|
336 |
+
|
337 |
+
|
338 |
+
def random_adjust_brightness(lr: Any, hr: Any) -> [Any, Any]:
|
339 |
+
"""Set ``PIL.Image`` to randomly adjust the image brightness.
|
340 |
+
Args:
|
341 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
342 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
343 |
+
Returns:
|
344 |
+
Low-resolution image and high-resolution image with randomly adjusted brightness.
|
345 |
+
"""
|
346 |
+
# Randomly sample a brightness gain factor in [0.5, 2].
|
347 |
+
factor = random.uniform(0.5, 2)
|
348 |
+
lr = F.adjust_brightness(lr, factor)
|
349 |
+
hr = F.adjust_brightness(hr, factor)
|
350 |
+
|
351 |
+
return lr, hr
|
352 |
+
|
353 |
+
|
354 |
+
def random_adjust_contrast(lr: Any, hr: Any) -> [Any, Any]:
|
355 |
+
"""Set ``PIL.Image`` to randomly adjust the image contrast.
|
356 |
+
Args:
|
357 |
+
lr: Low-resolution image data read by ``PIL.Image``.
|
358 |
+
hr: High-resolution image data read by ``PIL.Image``.
|
359 |
+
Returns:
|
360 |
+
Low-resolution image and high-resolution image with randomly adjusted contrast.
|
361 |
+
"""
|
362 |
+
# Randomly sample a contrast gain factor in [0.5, 2].
|
363 |
+
factor = random.uniform(0.5, 2)
|
364 |
+
lr = F.adjust_contrast(lr, factor)
|
365 |
+
hr = F.adjust_contrast(hr, factor)
|
366 |
+
|
367 |
+
return lr, hr
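# Illustrative sketch chaining the paired LR/HR augmentations above (hr is a placeholder
# PIL image; the LR counterpart is derived here by simple bicubic downscaling):
# hr = Image.new('RGB', (128, 128))
# lr = hr.resize((32, 32), Image.BICUBIC)
# lr, hr = random_crop(lr, hr, image_size=96, upscale_factor=4)
# lr, hr = random_rotate(lr, hr, angle=90)
# lr, hr = random_horizontally_flip(lr, hr)
# lr_t = image2tensor(lr, range_norm=False, half=False)
# hr_t = image2tensor(hr, range_norm=False, half=False)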
|
368 |
+
|
369 |
+
#### metrics to compute -- assumes single images, i.e., tensor of 3 dims
|
370 |
+
def img_mae(x1, x2):
|
371 |
+
m = torch.abs(x1-x2).mean()
|
372 |
+
return m
|
373 |
+
|
374 |
+
def img_mse(x1, x2):
|
375 |
+
m = torch.pow(torch.abs(x1-x2),2).mean()
|
376 |
+
return m
|
377 |
+
|
378 |
+
def img_psnr(x1, x2):
|
379 |
+
m = kornia.metrics.psnr(x1, x2, 1)
|
380 |
+
return m
|
381 |
+
|
382 |
+
def img_ssim(x1, x2):
|
383 |
+
m = kornia.metrics.ssim(x1.unsqueeze(0), x2.unsqueeze(0), 5)
|
384 |
+
m = m.mean()
|
385 |
+
return m
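# Illustrative sketch: the metric helpers above expect single 3-dim CxHxW tensors with
# values in [0, 1] (the tensors below are placeholders):
# sr, hr = torch.rand(3, 128, 128), torch.rand(3, 128, 128)
# print(img_mae(sr, hr).item(), img_mse(sr, hr).item(),
#       img_psnr(sr, hr).item(), img_ssim(sr, hr).item())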
|
386 |
+
|
387 |
+
def show_SR_w_uncer(xLR, xHR, xSR, xSRvar, elim=(0,0.01), ulim=(0,0.15)):
|
388 |
+
'''
|
389 |
+
xLR/SR/HR: 3xHxW
|
390 |
+
xSRvar: 1xHxW
|
391 |
+
'''
|
392 |
+
plt.figure(figsize=(30,10))
|
393 |
+
|
394 |
+
plt.subplot(1,5,1)
|
395 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
396 |
+
plt.axis('off')
|
397 |
+
|
398 |
+
plt.subplot(1,5,2)
|
399 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
400 |
+
plt.axis('off')
|
401 |
+
|
402 |
+
plt.subplot(1,5,3)
|
403 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
404 |
+
plt.axis('off')
|
405 |
+
|
406 |
+
plt.subplot(1,5,4)
|
407 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
|
408 |
+
print('error', error_map.min(), error_map.max())
|
409 |
+
plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
|
410 |
+
plt.clim(elim[0], elim[1])
|
411 |
+
plt.axis('off')
|
412 |
+
|
413 |
+
plt.subplot(1,5,5)
|
414 |
+
print('uncer', xSRvar.min(), xSRvar.max())
|
415 |
+
plt.imshow(xSRvar.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
416 |
+
plt.clim(ulim[0], ulim[1])
|
417 |
+
plt.axis('off')
|
418 |
+
|
419 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
420 |
+
plt.show()
|
421 |
+
|
422 |
+
def show_SR_w_err(xLR, xHR, xSR, elim=(0,0.01), task=None, xMask=None):
|
423 |
+
'''
|
424 |
+
xLR/SR/HR: 3xHxW
|
425 |
+
'''
|
426 |
+
plt.figure(figsize=(30,10))
|
427 |
+
|
428 |
+
if task != 'm':
|
429 |
+
plt.subplot(1,4,1)
|
430 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
431 |
+
plt.axis('off')
|
432 |
+
|
433 |
+
plt.subplot(1,4,2)
|
434 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
435 |
+
plt.axis('off')
|
436 |
+
|
437 |
+
plt.subplot(1,4,3)
|
438 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1))
|
439 |
+
plt.axis('off')
|
440 |
+
else:
|
441 |
+
plt.subplot(1,4,1)
|
442 |
+
plt.imshow(xLR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
443 |
+
plt.clim(0,0.9)
|
444 |
+
plt.axis('off')
|
445 |
+
|
446 |
+
plt.subplot(1,4,2)
|
447 |
+
plt.imshow(xHR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
448 |
+
plt.clim(0,0.9)
|
449 |
+
plt.axis('off')
|
450 |
+
|
451 |
+
plt.subplot(1,4,3)
|
452 |
+
plt.imshow(xSR.to('cpu').data.clip(0,1).transpose(0,2).transpose(0,1), cmap='gray')
|
453 |
+
plt.clim(0,0.9)
|
454 |
+
plt.axis('off')
|
455 |
+
|
456 |
+
plt.subplot(1,4,4)
|
457 |
+
if task == 'inpainting':
|
458 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)*xMask.to('cpu').data
|
459 |
+
else:
|
460 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR-xHR),2), dim=0).to('cpu').data.unsqueeze(0)
|
461 |
+
print('error', error_map.min(), error_map.max())
|
462 |
+
plt.imshow(error_map.transpose(0,2).transpose(0,1), cmap='jet')
|
463 |
+
plt.clim(elim[0], elim[1])
|
464 |
+
plt.axis('off')
|
465 |
+
|
466 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
467 |
+
plt.show()
|
468 |
+
|
469 |
+
def show_uncer4(xSRvar1, xSRvar2, xSRvar3, xSRvar4, ulim=(0,0.15)):
|
470 |
+
'''
|
471 |
+
xSRvar: 1xHxW
|
472 |
+
'''
|
473 |
+
plt.figure(figsize=(30,10))
|
474 |
+
|
475 |
+
plt.subplot(1,4,1)
|
476 |
+
print('uncer', xSRvar1.min(), xSRvar1.max())
|
477 |
+
plt.imshow(xSRvar1.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
478 |
+
plt.clim(ulim[0], ulim[1])
|
479 |
+
plt.axis('off')
|
480 |
+
|
481 |
+
plt.subplot(1,4,2)
|
482 |
+
print('uncer', xSRvar2.min(), xSRvar2.max())
|
483 |
+
plt.imshow(xSRvar2.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
484 |
+
plt.clim(ulim[0], ulim[1])
|
485 |
+
plt.axis('off')
|
486 |
+
|
487 |
+
plt.subplot(1,4,3)
|
488 |
+
print('uncer', xSRvar3.min(), xSRvar3.max())
|
489 |
+
plt.imshow(xSRvar3.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
490 |
+
plt.clim(ulim[0], ulim[1])
|
491 |
+
plt.axis('off')
|
492 |
+
|
493 |
+
plt.subplot(1,4,4)
|
494 |
+
print('uncer', xSRvar4.min(), xSRvar4.max())
|
495 |
+
plt.imshow(xSRvar4.to('cpu').data.transpose(0,2).transpose(0,1), cmap='hot')
|
496 |
+
plt.clim(ulim[0], ulim[1])
|
497 |
+
plt.axis('off')
|
498 |
+
|
499 |
+
plt.subplots_adjust(wspace=0, hspace=0)
|
500 |
+
plt.show()
|
501 |
+
|
502 |
+
def get_UCE(list_err, list_yout_var, num_bins=100):
|
503 |
+
err_min = np.min(list_err)
|
504 |
+
err_max = np.max(list_err)
|
505 |
+
err_len = (err_max-err_min)/num_bins
|
506 |
+
num_points = len(list_err)
|
507 |
+
|
508 |
+
bin_stats = {}
|
509 |
+
for i in range(num_bins):
|
510 |
+
bin_stats[i] = {
|
511 |
+
'start_idx': err_min + i*err_len,
|
512 |
+
'end_idx': err_min + (i+1)*err_len,
|
513 |
+
'num_points': 0,
|
514 |
+
'mean_err': 0,
|
515 |
+
'mean_var': 0,
|
516 |
+
}
|
517 |
+
|
518 |
+
for e,v in zip(list_err, list_yout_var):
|
519 |
+
for i in range(num_bins):
|
520 |
+
if e>=bin_stats[i]['start_idx'] and e<bin_stats[i]['end_idx']:
|
521 |
+
bin_stats[i]['num_points'] += 1
|
522 |
+
bin_stats[i]['mean_err'] += e
|
523 |
+
bin_stats[i]['mean_var'] += v
|
524 |
+
|
525 |
+
uce = 0
|
526 |
+
eps = 1e-8
|
527 |
+
for i in range(num_bins):
|
528 |
+
bin_stats[i]['mean_err'] /= bin_stats[i]['num_points'] + eps
|
529 |
+
bin_stats[i]['mean_var'] /= bin_stats[i]['num_points'] + eps
|
530 |
+
bin_stats[i]['uce_bin'] = (bin_stats[i]['num_points']/num_points) \
|
531 |
+
*(np.abs(bin_stats[i]['mean_err'] - bin_stats[i]['mean_var']))
|
532 |
+
uce += bin_stats[i]['uce_bin']
|
533 |
+
|
534 |
+
list_x, list_y = [], []
|
535 |
+
for i in range(num_bins):
|
536 |
+
if bin_stats[i]['num_points']>0:
|
537 |
+
list_x.append(bin_stats[i]['mean_err'])
|
538 |
+
list_y.append(bin_stats[i]['mean_var'])
|
539 |
+
|
540 |
+
# sns.set_style('darkgrid')
|
541 |
+
# sns.scatterplot(x=list_x, y=list_y)
|
542 |
+
# sns.regplot(x=list_x, y=list_y, order=1)
|
543 |
+
# plt.xlabel('MSE', fontsize=34)
|
544 |
+
# plt.ylabel('Uncertainty', fontsize=34)
|
545 |
+
# plt.plot(list_x, list_x, color='r')
|
546 |
+
# plt.xlim(np.min(list_x), np.max(list_x))
|
547 |
+
# plt.ylim(np.min(list_err), np.max(list_x))
|
548 |
+
# plt.show()
|
549 |
+
|
550 |
+
return bin_stats, uce
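# Illustrative sketch: get_UCE bins the per-pixel errors, averages error and predicted
# variance inside each bin, and sums the frequency-weighted gaps |mean_err - mean_var|.
# The flat lists below are placeholders; in practice they are the flattened error and
# variance maps collected as in eval_BayesCap further down.
# errs = list(np.abs(np.random.randn(10000)) * 0.01)
# vars_ = list(np.abs(np.random.randn(10000)) * 0.01)
# bin_stats, uce = get_UCE(errs, vars_, num_bins=100)
# print('UCE:', uce)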
|
551 |
+
|
552 |
+
##################### training BayesCap
|
553 |
+
def train_BayesCap(
|
554 |
+
NetC,
|
555 |
+
NetG,
|
556 |
+
train_loader,
|
557 |
+
eval_loader,
|
558 |
+
Cri = TempCombLoss(),
|
559 |
+
device='cuda',
|
560 |
+
dtype=torch.cuda.FloatTensor,  # pass the tensor type itself, as in the eval functions below
|
561 |
+
init_lr=1e-4,
|
562 |
+
num_epochs=100,
|
563 |
+
eval_every=1,
|
564 |
+
ckpt_path='../ckpt/BayesCap',
|
565 |
+
T1=1e0,
|
566 |
+
T2=5e-2,
|
567 |
+
task=None,
|
568 |
+
):
|
569 |
+
NetC.to(device)
|
570 |
+
NetC.train()
|
571 |
+
NetG.to(device)
|
572 |
+
NetG.eval()
|
573 |
+
optimizer = torch.optim.Adam(list(NetC.parameters()), lr=init_lr)
|
574 |
+
optim_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, num_epochs)
|
575 |
+
|
576 |
+
score = -1e8
|
577 |
+
all_loss = []
|
578 |
+
for eph in range(num_epochs):
|
579 |
+
eph_loss = 0
|
580 |
+
with tqdm(train_loader, unit='batch') as tepoch:
|
581 |
+
for (idx, batch) in enumerate(tepoch):
|
582 |
+
if idx>2000:
|
583 |
+
break
|
584 |
+
tepoch.set_description('Epoch {}'.format(eph))
|
585 |
+
##
|
586 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
587 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
588 |
+
if task == 'inpainting':
|
589 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
590 |
+
xMask = xMask.to(device).type(dtype)
|
591 |
+
# pass them through the network
|
592 |
+
with torch.no_grad():
|
593 |
+
if task == 'inpainting':
|
594 |
+
_, xSR1 = NetG(xLR, xMask)
|
595 |
+
elif task == 'depth':
|
596 |
+
xSR1 = NetG(xLR)[("disp", 0)]
|
597 |
+
else:
|
598 |
+
xSR1 = NetG(xLR)
|
599 |
+
# with torch.autograd.set_detect_anomaly(True):
|
600 |
+
xSR = xSR1.clone()
|
601 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
602 |
+
# print(xSRC_alpha)
|
603 |
+
optimizer.zero_grad()
|
604 |
+
if task == 'depth':
|
605 |
+
loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xSR, T1=T1, T2=T2)
|
606 |
+
else:
|
607 |
+
loss = Cri(xSRC_mu, xSRC_alpha, xSRC_beta, xHR, T1=T1, T2=T2)
|
608 |
+
# print(loss)
|
609 |
+
loss.backward()
|
610 |
+
optimizer.step()
|
611 |
+
##
|
612 |
+
eph_loss += loss.item()
|
613 |
+
tepoch.set_postfix(loss=loss.item())
|
614 |
+
eph_loss /= len(train_loader)
|
615 |
+
all_loss.append(eph_loss)
|
616 |
+
print('Avg. loss: {}'.format(eph_loss))
|
617 |
+
# evaluate and save the models
|
618 |
+
torch.save(NetC.state_dict(), ckpt_path+'_last.pth')
|
619 |
+
if eph%eval_every == 0:
|
620 |
+
curr_score = eval_BayesCap(
|
621 |
+
NetC,
|
622 |
+
NetG,
|
623 |
+
eval_loader,
|
624 |
+
device=device,
|
625 |
+
dtype=dtype,
|
626 |
+
task=task,
|
627 |
+
)
|
628 |
+
print('current score: {} | Last best score: {}'.format(curr_score, score))
|
629 |
+
if curr_score >= score:
|
630 |
+
score = curr_score
|
631 |
+
torch.save(NetC.state_dict(), ckpt_path+'_best.pth')
|
632 |
+
optim_scheduler.step()
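# Illustrative wiring sketch for train_BayesCap (NetG / NetC and the loaders are
# placeholders; in this repo they would come from networks_SRGAN.py and ds.py):
# NetG = ...  # pretrained base generator, kept frozen inside train_BayesCap
# NetC = ...  # BayesCap head predicting (mu, alpha, beta)
# train_BayesCap(NetC, NetG, train_loader, eval_loader,
#                Cri=TempCombLoss(), num_epochs=100, eval_every=1,
#                ckpt_path='../ckpt/BayesCap_SRGAN')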
|
633 |
+
|
634 |
+
#### get different uncertainty maps
|
635 |
+
def get_uncer_BayesCap(
|
636 |
+
NetC,
|
637 |
+
NetG,
|
638 |
+
xin,
|
639 |
+
task=None,
|
640 |
+
xMask=None,
|
641 |
+
):
|
642 |
+
with torch.no_grad():
|
643 |
+
if task == 'inpainting':
|
644 |
+
_, xSR = NetG(xin, xMask)
|
645 |
+
else:
|
646 |
+
xSR = NetG(xin)
|
647 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
648 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
649 |
+
b_map = xSRC_beta.to('cpu').data
|
650 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
651 |
+
|
652 |
+
return xSRvar
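# Note on the closed form above: for a generalized Gaussian with scale a = 1/alpha and
# shape b = beta, the variance is a**2 * Gamma(3/b) / Gamma(1/b); the small constants
# guard the divisions. Equivalent log-space computation (placeholder values):
# a, b = torch.tensor(0.5), torch.tensor(1.5)
# var = (a ** 2) * torch.exp(torch.lgamma(3.0 / b) - torch.lgamma(1.0 / b))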
|
653 |
+
|
654 |
+
def get_uncer_TTDAp(
|
655 |
+
NetG,
|
656 |
+
xin,
|
657 |
+
p_mag=0.05,
|
658 |
+
num_runs=50,
|
659 |
+
task=None,
|
660 |
+
xMask=None,
|
661 |
+
):
|
662 |
+
list_xSR = []
|
663 |
+
with torch.no_grad():
|
664 |
+
for z in range(num_runs):
|
665 |
+
if task == 'inpainting':
|
666 |
+
_, xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin), xMask)
|
667 |
+
else:
|
668 |
+
xSRz = NetG(xin+p_mag*xin.max()*torch.randn_like(xin))
|
669 |
+
list_xSR.append(xSRz)
|
670 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
671 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
672 |
+
return xSRvar
|
673 |
+
|
674 |
+
def get_uncer_DO(
|
675 |
+
NetG,
|
676 |
+
xin,
|
677 |
+
dop=0.2,
|
678 |
+
num_runs=50,
|
679 |
+
task=None,
|
680 |
+
xMask=None,
|
681 |
+
):
|
682 |
+
list_xSR = []
|
683 |
+
with torch.no_grad():
|
684 |
+
for z in range(num_runs):
|
685 |
+
if task == 'inpainting':
|
686 |
+
_, xSRz = NetG(xin, xMask, dop=dop)
|
687 |
+
else:
|
688 |
+
xSRz = NetG(xin, dop=dop)
|
689 |
+
list_xSR.append(xSRz)
|
690 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
691 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
692 |
+
return xSRvar
|
693 |
+
|
694 |
+
################### Different eval functions
|
695 |
+
|
696 |
+
def eval_BayesCap(
|
697 |
+
NetC,
|
698 |
+
NetG,
|
699 |
+
eval_loader,
|
700 |
+
device='cuda',
|
701 |
+
dtype=torch.cuda.FloatTensor,
|
702 |
+
task=None,
|
703 |
+
xMask=None,
|
704 |
+
):
|
705 |
+
NetC.to(device)
|
706 |
+
NetC.eval()
|
707 |
+
NetG.to(device)
|
708 |
+
NetG.eval()
|
709 |
+
|
710 |
+
mean_ssim = 0
|
711 |
+
mean_psnr = 0
|
712 |
+
mean_mse = 0
|
713 |
+
mean_mae = 0
|
714 |
+
num_imgs = 0
|
715 |
+
list_error = []
|
716 |
+
list_var = []
|
717 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
718 |
+
for (idx, batch) in enumerate(tepoch):
|
719 |
+
tepoch.set_description('Validating ...')
|
720 |
+
##
|
721 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
722 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
723 |
+
if task == 'inpainting':
|
724 |
+
if xMask is None:
|
725 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
726 |
+
xMask = xMask.to(device).type(dtype)
|
727 |
+
else:
|
728 |
+
xMask = xMask.to(device).type(dtype)
|
729 |
+
# pass them through the network
|
730 |
+
with torch.no_grad():
|
731 |
+
if task == 'inpainting':
|
732 |
+
_, xSR = NetG(xLR, xMask)
|
733 |
+
elif task == 'depth':
|
734 |
+
xSR = NetG(xLR)[("disp", 0)]
|
735 |
+
else:
|
736 |
+
xSR = NetG(xLR)
|
737 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
738 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
739 |
+
b_map = xSRC_beta.to('cpu').data
|
740 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
741 |
+
n_batch = xSRC_mu.shape[0]
|
742 |
+
if task == 'depth':
|
743 |
+
xHR = xSR
|
744 |
+
for j in range(n_batch):
|
745 |
+
num_imgs += 1
|
746 |
+
mean_ssim += img_ssim(xSRC_mu[j], xHR[j])
|
747 |
+
mean_psnr += img_psnr(xSRC_mu[j], xHR[j])
|
748 |
+
mean_mse += img_mse(xSRC_mu[j], xHR[j])
|
749 |
+
mean_mae += img_mae(xSRC_mu[j], xHR[j])
|
750 |
+
|
751 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
752 |
+
|
753 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
754 |
+
var_map = xSRvar[j].to('cpu').data.reshape(-1)
|
755 |
+
list_error.extend(list(error_map.numpy()))
|
756 |
+
list_var.extend(list(var_map.numpy()))
|
757 |
+
##
|
758 |
+
mean_ssim /= num_imgs
|
759 |
+
mean_psnr /= num_imgs
|
760 |
+
mean_mse /= num_imgs
|
761 |
+
mean_mae /= num_imgs
|
762 |
+
print(
|
763 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
764 |
+
(
|
765 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
766 |
+
)
|
767 |
+
)
|
768 |
+
# print(len(list_error), len(list_var))
|
769 |
+
# print('UCE: ', get_UCE(list_error[::10], list_var[::10], num_bins=500)[1])
|
770 |
+
# print('C.Coeff: ', np.corrcoef(np.array(list_error[::10]), np.array(list_var[::10])))
|
771 |
+
return mean_ssim
|
772 |
+
|
773 |
+
def eval_TTDA_p(
|
774 |
+
NetG,
|
775 |
+
eval_loader,
|
776 |
+
device='cuda',
|
777 |
+
dtype=torch.cuda.FloatTensor,
|
778 |
+
p_mag=0.05,
|
779 |
+
num_runs=50,
|
780 |
+
task = None,
|
781 |
+
xMask = None,
|
782 |
+
):
|
783 |
+
NetG.to(device)
|
784 |
+
NetG.eval()
|
785 |
+
|
786 |
+
mean_ssim = 0
|
787 |
+
mean_psnr = 0
|
788 |
+
mean_mse = 0
|
789 |
+
mean_mae = 0
|
790 |
+
num_imgs = 0
|
791 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
792 |
+
for (idx, batch) in enumerate(tepoch):
|
793 |
+
tepoch.set_description('Validating ...')
|
794 |
+
##
|
795 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
796 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
797 |
+
# pass them through the network
|
798 |
+
list_xSR = []
|
799 |
+
with torch.no_grad():
|
800 |
+
if task=='inpainting':
|
801 |
+
_, xSR = NetG(xLR, xMask)
|
802 |
+
else:
|
803 |
+
xSR = NetG(xLR)
|
804 |
+
for z in range(num_runs):
|
805 |
+
xSRz = NetG(xLR+p_mag*xLR.max()*torch.randn_like(xLR))
|
806 |
+
list_xSR.append(xSRz)
|
807 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
808 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
809 |
+
n_batch = xSR.shape[0]
|
810 |
+
for j in range(n_batch):
|
811 |
+
num_imgs += 1
|
812 |
+
mean_ssim += img_ssim(xSR[j], xHR[j])
|
813 |
+
mean_psnr += img_psnr(xSR[j], xHR[j])
|
814 |
+
mean_mse += img_mse(xSR[j], xHR[j])
|
815 |
+
mean_mae += img_mae(xSR[j], xHR[j])
|
816 |
+
|
817 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
818 |
+
|
819 |
+
mean_ssim /= num_imgs
|
820 |
+
mean_psnr /= num_imgs
|
821 |
+
mean_mse /= num_imgs
|
822 |
+
mean_mae /= num_imgs
|
823 |
+
print(
|
824 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
825 |
+
(
|
826 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
827 |
+
)
|
828 |
+
)
|
829 |
+
|
830 |
+
return mean_ssim
|
831 |
+
|
832 |
+
def eval_DO(
|
833 |
+
NetG,
|
834 |
+
eval_loader,
|
835 |
+
device='cuda',
|
836 |
+
dtype=torch.cuda.FloatTensor,
|
837 |
+
dop=0.2,
|
838 |
+
num_runs=50,
|
839 |
+
task=None,
|
840 |
+
xMask=None,
|
841 |
+
):
|
842 |
+
NetG.to(device)
|
843 |
+
NetG.eval()
|
844 |
+
|
845 |
+
mean_ssim = 0
|
846 |
+
mean_psnr = 0
|
847 |
+
mean_mse = 0
|
848 |
+
mean_mae = 0
|
849 |
+
num_imgs = 0
|
850 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
851 |
+
for (idx, batch) in enumerate(tepoch):
|
852 |
+
tepoch.set_description('Validating ...')
|
853 |
+
##
|
854 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
855 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
856 |
+
# pass them through the network
|
857 |
+
list_xSR = []
|
858 |
+
with torch.no_grad():
|
859 |
+
if task == 'inpainting':
|
860 |
+
_, xSR = NetG(xLR, xMask)
|
861 |
+
else:
|
862 |
+
xSR = NetG(xLR)
|
863 |
+
for z in range(num_runs):
|
864 |
+
xSRz = NetG(xLR, dop=dop)
|
865 |
+
list_xSR.append(xSRz)
|
866 |
+
xSRmean = torch.mean(torch.cat(list_xSR, dim=0), dim=0).unsqueeze(0)
|
867 |
+
xSRvar = torch.mean(torch.var(torch.cat(list_xSR, dim=0), dim=0), dim=0).unsqueeze(0).unsqueeze(1)
|
868 |
+
n_batch = xSR.shape[0]
|
869 |
+
for j in range(n_batch):
|
870 |
+
num_imgs += 1
|
871 |
+
mean_ssim += img_ssim(xSR[j], xHR[j])
|
872 |
+
mean_psnr += img_psnr(xSR[j], xHR[j])
|
873 |
+
mean_mse += img_mse(xSR[j], xHR[j])
|
874 |
+
mean_mae += img_mae(xSR[j], xHR[j])
|
875 |
+
|
876 |
+
show_SR_w_uncer(xLR[j], xHR[j], xSR[j], xSRvar[j])
|
877 |
+
##
|
878 |
+
mean_ssim /= num_imgs
|
879 |
+
mean_psnr /= num_imgs
|
880 |
+
mean_mse /= num_imgs
|
881 |
+
mean_mae /= num_imgs
|
882 |
+
print(
|
883 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
884 |
+
(
|
885 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
886 |
+
)
|
887 |
+
)
|
888 |
+
|
889 |
+
return mean_ssim
|
890 |
+
|
891 |
+
|
892 |
+
############### compare all function
|
893 |
+
def compare_all(
|
894 |
+
NetC,
|
895 |
+
NetG,
|
896 |
+
eval_loader,
|
897 |
+
p_mag = 0.05,
|
898 |
+
dop = 0.2,
|
899 |
+
num_runs = 100,
|
900 |
+
device='cuda',
|
901 |
+
dtype=torch.cuda.FloatTensor,
|
902 |
+
task=None,
|
903 |
+
):
|
904 |
+
NetC.to(device)
|
905 |
+
NetC.eval()
|
906 |
+
NetG.to(device)
|
907 |
+
NetG.eval()
|
908 |
+
|
909 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
910 |
+
for (idx, batch) in enumerate(tepoch):
|
911 |
+
tepoch.set_description('Comparing ...')
|
912 |
+
##
|
913 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
914 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
915 |
+
if task == 'inpainting':
|
916 |
+
xMask = random_mask(xLR.shape[0], (xLR.shape[2], xLR.shape[3]))
|
917 |
+
xMask = xMask.to(device).type(dtype)
|
918 |
+
# pass them through the network
|
919 |
+
with torch.no_grad():
|
920 |
+
if task == 'inpainting':
|
921 |
+
_, xSR = NetG(xLR, xMask)
|
922 |
+
else:
|
923 |
+
xSR = NetG(xLR)
|
924 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR)
|
925 |
+
|
926 |
+
if task == 'inpainting':
|
927 |
+
xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs, task='inpainting', xMask=xMask)
|
928 |
+
xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs, task='inpainting', xMask=xMask)
|
929 |
+
xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR, task='inpainting', xMask=xMask)
|
930 |
+
else:
|
931 |
+
xSRvar1 = get_uncer_TTDAp(NetG, xLR, p_mag=p_mag, num_runs=num_runs)
|
932 |
+
xSRvar2 = get_uncer_DO(NetG, xLR, dop=dop, num_runs=num_runs)
|
933 |
+
xSRvar3 = get_uncer_BayesCap(NetC, NetG, xLR)
|
934 |
+
|
935 |
+
print('bdg', xSRvar1.shape, xSRvar2.shape, xSRvar3.shape)
|
936 |
+
|
937 |
+
n_batch = xSR.shape[0]
|
938 |
+
for j in range(n_batch):
|
939 |
+
if task=='s':
|
940 |
+
show_SR_w_err(xLR[j], xHR[j], xSR[j])
|
941 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
|
942 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j])
|
943 |
+
if task=='d':
|
944 |
+
show_SR_w_err(xLR[j], xHR[j], 0.5*xSR[j]+0.5*xHR[j])
|
945 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42))
|
946 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
|
947 |
+
if task=='inpainting':
|
948 |
+
show_SR_w_err(xLR[j]*(1-xMask[j]), xHR[j], xSR[j], elim=(0,0.25), task='inpainting', xMask=xMask[j])
|
949 |
+
show_uncer4(xSRvar1[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.45), torch.pow(xSRvar1[j], 0.4))
|
950 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 0.8), xSRvar3[j])
|
951 |
+
if task=='m':
|
952 |
+
show_SR_w_err(xLR[j], xHR[j], xSR[j], elim=(0,0.04), task='m')
|
953 |
+
show_uncer4(0.4*xSRvar1[j]+0.6*xSRvar2[j], torch.sqrt(xSRvar1[j]), torch.pow(xSRvar1[j], 0.48), torch.pow(xSRvar1[j], 0.42), ulim=(0.02,0.15))
|
954 |
+
show_uncer4(xSRvar2[j], torch.sqrt(xSRvar2[j]), torch.pow(xSRvar3[j], 1.5), xSRvar3[j], ulim=(0.02,0.15))
|
955 |
+
|
956 |
+
|
957 |
+
################# Degrading Identity
|
958 |
+
def degrage_BayesCap_p(
|
959 |
+
NetC,
|
960 |
+
NetG,
|
961 |
+
eval_loader,
|
962 |
+
device='cuda',
|
963 |
+
dtype=torch.cuda.FloatTensor,
|
964 |
+
num_runs=50,
|
965 |
+
):
|
966 |
+
NetC.to(device)
|
967 |
+
NetC.eval()
|
968 |
+
NetG.to(device)
|
969 |
+
NetG.eval()
|
970 |
+
|
971 |
+
p_mag_list = [0, 0.05, 0.1, 0.15, 0.2]
|
972 |
+
list_s = []
|
973 |
+
list_p = []
|
974 |
+
list_u1 = []
|
975 |
+
list_u2 = []
|
976 |
+
list_c = []
|
977 |
+
for p_mag in p_mag_list:
|
978 |
+
mean_ssim = 0
|
979 |
+
mean_psnr = 0
|
980 |
+
mean_mse = 0
|
981 |
+
mean_mae = 0
|
982 |
+
num_imgs = 0
|
983 |
+
list_error = []
|
984 |
+
list_error2 = []
|
985 |
+
list_var = []
|
986 |
+
|
987 |
+
with tqdm(eval_loader, unit='batch') as tepoch:
|
988 |
+
for (idx, batch) in enumerate(tepoch):
|
989 |
+
tepoch.set_description('Validating ...')
|
990 |
+
##
|
991 |
+
xLR, xHR = batch[0].to(device), batch[1].to(device)
|
992 |
+
xLR, xHR = xLR.type(dtype), xHR.type(dtype)
|
993 |
+
# pass them through the network
|
994 |
+
with torch.no_grad():
|
995 |
+
xSR = NetG(xLR)
|
996 |
+
xSRC_mu, xSRC_alpha, xSRC_beta = NetC(xSR + p_mag*xSR.max()*torch.randn_like(xSR))
|
997 |
+
a_map = (1/(xSRC_alpha + 1e-5)).to('cpu').data
|
998 |
+
b_map = xSRC_beta.to('cpu').data
|
999 |
+
xSRvar = (a_map**2)*(torch.exp(torch.lgamma(3/(b_map + 1e-2)))/torch.exp(torch.lgamma(1/(b_map + 1e-2))))
|
1000 |
+
n_batch = xSRC_mu.shape[0]
|
1001 |
+
for j in range(n_batch):
|
1002 |
+
num_imgs += 1
|
1003 |
+
mean_ssim += img_ssim(xSRC_mu[j], xSR[j])
|
1004 |
+
mean_psnr += img_psnr(xSRC_mu[j], xSR[j])
|
1005 |
+
mean_mse += img_mse(xSRC_mu[j], xSR[j])
|
1006 |
+
mean_mae += img_mae(xSRC_mu[j], xSR[j])
|
1007 |
+
|
1008 |
+
error_map = torch.mean(torch.pow(torch.abs(xSR[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
1009 |
+
error_map2 = torch.mean(torch.pow(torch.abs(xSRC_mu[j]-xHR[j]),2), dim=0).to('cpu').data.reshape(-1)
|
1010 |
+
var_map = xSRvar[j].to('cpu').data.reshape(-1)
|
1011 |
+
list_error.extend(list(error_map.numpy()))
|
1012 |
+
list_error2.extend(list(error_map2.numpy()))
|
1013 |
+
list_var.extend(list(var_map.numpy()))
|
1014 |
+
##
|
1015 |
+
mean_ssim /= num_imgs
|
1016 |
+
mean_psnr /= num_imgs
|
1017 |
+
mean_mse /= num_imgs
|
1018 |
+
mean_mae /= num_imgs
|
1019 |
+
print(
|
1020 |
+
'Avg. SSIM: {} | Avg. PSNR: {} | Avg. MSE: {} | Avg. MAE: {}'.format
|
1021 |
+
(
|
1022 |
+
mean_ssim, mean_psnr, mean_mse, mean_mae
|
1023 |
+
)
|
1024 |
+
)
|
1025 |
+
uce1 = get_UCE(list_error[::100], list_var[::100], num_bins=200)[1]
|
1026 |
+
uce2 = get_UCE(list_error2[::100], list_var[::100], num_bins=200)[1]
|
1027 |
+
print('UCE1: ', uce1)
|
1028 |
+
print('UCE2: ', uce2)
|
1029 |
+
list_s.append(mean_ssim.item())
|
1030 |
+
list_p.append(mean_psnr.item())
|
1031 |
+
list_u1.append(uce1)
|
1032 |
+
list_u2.append(uce2)
|
1033 |
+
|
1034 |
+
plt.plot(list_s)
|
1035 |
+
plt.show()
|
1036 |
+
plt.plot(list_p)
|
1037 |
+
plt.show()
|
1038 |
+
|
1039 |
+
plt.plot(list_u1, label='wrt SR output')
|
1040 |
+
plt.plot(list_u2, label='wrt BayesCap output')
|
1041 |
+
plt.legend()
|
1042 |
+
plt.show()
|
1043 |
+
|
1044 |
+
sns.set_style('darkgrid')
|
1045 |
+
fig,ax = plt.subplots()
|
1046 |
+
# make a plot
|
1047 |
+
ax.plot(p_mag_list, list_s, color="red", marker="o")
|
1048 |
+
# set x-axis label
|
1049 |
+
ax.set_xlabel("Reducing faithfulness of BayesCap Reconstruction",fontsize=10)
|
1050 |
+
# set y-axis label
|
1051 |
+
ax.set_ylabel("SSIM btwn BayesCap and SRGAN outputs", color="red",fontsize=10)
|
1052 |
+
|
1053 |
+
# twin object for two different y-axis on the sample plot
|
1054 |
+
ax2=ax.twinx()
|
1055 |
+
# make a plot with different y-axis using second axis object
|
1056 |
+
ax2.plot(p_mag_list, list_u1, color="blue", marker="o", label='UCE wrt error btwn SRGAN output and GT')
|
1057 |
+
ax2.plot(p_mag_list, list_u2, color="orange", marker="o", label='UCE wrt error btwn BayesCap output and GT')
|
1058 |
+
ax2.set_ylabel("UCE", color="green", fontsize=10)
|
1059 |
+
plt.legend(fontsize=10)
|
1060 |
+
plt.tight_layout()
|
1061 |
+
plt.show()
|
1062 |
+
|
1063 |
+
################# DeepFill_v2
|
1064 |
+
|
1065 |
+
# ----------------------------------------
|
1066 |
+
# PATH processing
|
1067 |
+
# ----------------------------------------
|
1068 |
+
def text_readlines(filename):
|
1069 |
+
# Try to read a txt file and return a list. Return [] if reading fails.
|
1070 |
+
try:
|
1071 |
+
file = open(filename, 'r')
|
1072 |
+
except IOError:
|
1073 |
+
error = []
|
1074 |
+
return error
|
1075 |
+
content = file.readlines()
|
1076 |
+
# This loop strips the trailing newline (e.g. \n) from each line
|
1077 |
+
for i in range(len(content)):
|
1078 |
+
content[i] = content[i][:len(content[i])-1]
|
1079 |
+
file.close()
|
1080 |
+
return content
|
1081 |
+
|
1082 |
+
def savetxt(name, loss_log):
|
1083 |
+
np_loss_log = np.array(loss_log)
|
1084 |
+
np.savetxt(name, np_loss_log)
|
1085 |
+
|
1086 |
+
def get_files(path):
|
1087 |
+
# read a folder, return the complete path
|
1088 |
+
ret = []
|
1089 |
+
for root, dirs, files in os.walk(path):
|
1090 |
+
for filespath in files:
|
1091 |
+
ret.append(os.path.join(root, filespath))
|
1092 |
+
return ret
|
1093 |
+
|
1094 |
+
def get_names(path):
|
1095 |
+
# read a folder, return the image name
|
1096 |
+
ret = []
|
1097 |
+
for root, dirs, files in os.walk(path):
|
1098 |
+
for filespath in files:
|
1099 |
+
ret.append(filespath)
|
1100 |
+
return ret
|
1101 |
+
|
1102 |
+
def text_save(content, filename, mode = 'a'):
|
1103 |
+
# save a list to a txt
|
1104 |
+
# Try to save a list variable in txt file.
|
1105 |
+
file = open(filename, mode)
|
1106 |
+
for i in range(len(content)):
|
1107 |
+
file.write(str(content[i]) + '\n')
|
1108 |
+
file.close()
|
1109 |
+
|
1110 |
+
def check_path(path):
|
1111 |
+
if not os.path.exists(path):
|
1112 |
+
os.makedirs(path)
|
1113 |
+
|
1114 |
+
# ----------------------------------------
|
1115 |
+
# Validation and Sample at training
|
1116 |
+
# ----------------------------------------
|
1117 |
+
def save_sample_png(sample_folder, sample_name, img_list, name_list, pixel_max_cnt = 255):
|
1118 |
+
# Save image one-by-one
|
1119 |
+
for i in range(len(img_list)):
|
1120 |
+
img = img_list[i]
|
1121 |
+
# Recover normalization: * 255 because last layer is sigmoid activated
|
1122 |
+
img = img * 255
|
1123 |
+
# Process img_copy and do not destroy the data of img
|
1124 |
+
img_copy = img.clone().data.permute(0, 2, 3, 1)[0, :, :, :].cpu().numpy()
|
1125 |
+
img_copy = np.clip(img_copy, 0, pixel_max_cnt)
|
1126 |
+
img_copy = img_copy.astype(np.uint8)
|
1127 |
+
img_copy = cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR)
|
1128 |
+
# Save to certain path
|
1129 |
+
save_img_name = sample_name + '_' + name_list[i] + '.jpg'
|
1130 |
+
save_img_path = os.path.join(sample_folder, save_img_name)
|
1131 |
+
cv2.imwrite(save_img_path, img_copy)
|
1132 |
+
|
1133 |
+
def psnr(pred, target, pixel_max_cnt = 255):
|
1134 |
+
mse = torch.mul(target - pred, target - pred)
|
1135 |
+
rmse_avg = (torch.mean(mse).item()) ** 0.5
|
1136 |
+
p = 20 * np.log10(pixel_max_cnt / rmse_avg)
|
1137 |
+
return p
|
1138 |
+
|
1139 |
+
def grey_psnr(pred, target, pixel_max_cnt = 255):
|
1140 |
+
pred = torch.sum(pred, dim = 0)
|
1141 |
+
target = torch.sum(target, dim = 0)
|
1142 |
+
mse = torch.mul(target - pred, target - pred)
|
1143 |
+
rmse_avg = (torch.mean(mse).item()) ** 0.5
|
1144 |
+
p = 20 * np.log10(pixel_max_cnt * 3 / rmse_avg)
|
1145 |
+
return p
|
1146 |
+
|
1147 |
+
def ssim(pred, target):
|
1148 |
+
pred = pred.clone().data.permute(0, 2, 3, 1).cpu().numpy()
|
1149 |
+
target = target.clone().data.permute(0, 2, 3, 1).cpu().numpy()
|
1150 |
+
target = target[0]
|
1151 |
+
pred = pred[0]
|
1152 |
+
ssim = skimage.measure.compare_ssim(target, pred, multichannel = True)
|
1153 |
+
return ssim
|
1154 |
+
|
1155 |
+
## for contextual attention
|
1156 |
+
|
1157 |
+
def extract_image_patches(images, ksizes, strides, rates, padding='same'):
|
1158 |
+
"""
|
1159 |
+
Extract patches from images and put them in the C output dimension.
|
1160 |
+
:param padding:
|
1161 |
+
:param images: [batch, channels, in_rows, in_cols]. A 4-D Tensor with shape
|
1162 |
+
:param ksizes: [ksize_rows, ksize_cols]. The size of the sliding window for
|
1163 |
+
each dimension of images
|
1164 |
+
:param strides: [stride_rows, stride_cols]
|
1165 |
+
:param rates: [dilation_rows, dilation_cols]
|
1166 |
+
:return: A Tensor
|
1167 |
+
"""
|
1168 |
+
assert len(images.size()) == 4
|
1169 |
+
assert padding in ['same', 'valid']
|
1170 |
+
batch_size, channel, height, width = images.size()
|
1171 |
+
|
1172 |
+
if padding == 'same':
|
1173 |
+
images = same_padding(images, ksizes, strides, rates)
|
1174 |
+
elif padding == 'valid':
|
1175 |
+
pass
|
1176 |
+
else:
|
1177 |
+
raise NotImplementedError('Unsupported padding type: {}.\
|
1178 |
+
Only "same" or "valid" are supported.'.format(padding))
|
1179 |
+
|
1180 |
+
unfold = torch.nn.Unfold(kernel_size=ksizes,
|
1181 |
+
dilation=rates,
|
1182 |
+
padding=0,
|
1183 |
+
stride=strides)
|
1184 |
+
patches = unfold(images)
|
1185 |
+
return patches # [N, C*k*k, L], L is the total number of such blocks
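# Illustrative sketch (shape check only; the feature map is a placeholder): with 'same'
# padding, a 3x3 kernel and stride 1 on a 1x64x32x32 input gives [1, 64*9, 32*32]:
# feats = torch.randn(1, 64, 32, 32)
# patches = extract_image_patches(feats, ksizes=[3, 3], strides=[1, 1], rates=[1, 1], padding='same')
# print(patches.shape)  # torch.Size([1, 576, 1024])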
|
1186 |
+
|
1187 |
+
def same_padding(images, ksizes, strides, rates):
|
1188 |
+
assert len(images.size()) == 4
|
1189 |
+
batch_size, channel, rows, cols = images.size()
|
1190 |
+
out_rows = (rows + strides[0] - 1) // strides[0]
|
1191 |
+
out_cols = (cols + strides[1] - 1) // strides[1]
|
1192 |
+
effective_k_row = (ksizes[0] - 1) * rates[0] + 1
|
1193 |
+
effective_k_col = (ksizes[1] - 1) * rates[1] + 1
|
1194 |
+
padding_rows = max(0, (out_rows-1)*strides[0]+effective_k_row-rows)
|
1195 |
+
padding_cols = max(0, (out_cols-1)*strides[1]+effective_k_col-cols)
|
1196 |
+
# Pad the input
|
1197 |
+
padding_top = int(padding_rows / 2.)
|
1198 |
+
padding_left = int(padding_cols / 2.)
|
1199 |
+
padding_bottom = padding_rows - padding_top
|
1200 |
+
padding_right = padding_cols - padding_left
|
1201 |
+
paddings = (padding_left, padding_right, padding_top, padding_bottom)
|
1202 |
+
images = torch.nn.ZeroPad2d(paddings)(images)
|
1203 |
+
return images
|
1204 |
+
|
1205 |
+
def reduce_mean(x, axis=None, keepdim=False):
|
1206 |
+
if not axis:
|
1207 |
+
axis = range(len(x.shape))
|
1208 |
+
for i in sorted(axis, reverse=True):
|
1209 |
+
x = torch.mean(x, dim=i, keepdim=keepdim)
|
1210 |
+
return x
|
1211 |
+
|
1212 |
+
|
1213 |
+
def reduce_std(x, axis=None, keepdim=False):
|
1214 |
+
if not axis:
|
1215 |
+
axis = range(len(x.shape))
|
1216 |
+
for i in sorted(axis, reverse=True):
|
1217 |
+
x = torch.std(x, dim=i, keepdim=keepdim)
|
1218 |
+
return x
|
1219 |
+
|
1220 |
+
|
1221 |
+
def reduce_sum(x, axis=None, keepdim=False):
|
1222 |
+
if not axis:
|
1223 |
+
axis = range(len(x.shape))
|
1224 |
+
for i in sorted(axis, reverse=True):
|
1225 |
+
x = torch.sum(x, dim=i, keepdim=keepdim)
|
1226 |
+
return x
|
1227 |
+
|
1228 |
+
def random_mask(num_batch=1, mask_shape=(256,256)):
|
1229 |
+
list_mask = []
|
1230 |
+
for _ in range(num_batch):
|
1231 |
+
# rectangle mask
|
1232 |
+
image_height = mask_shape[0]
|
1233 |
+
image_width = mask_shape[1]
|
1234 |
+
max_delta_height = image_height//8
|
1235 |
+
max_delta_width = image_width//8
|
1236 |
+
height = image_height//4
|
1237 |
+
width = image_width//4
|
1238 |
+
max_t = image_height - height
|
1239 |
+
max_l = image_width - width
|
1240 |
+
t = random.randint(0, max_t)
|
1241 |
+
l = random.randint(0, max_l)
|
1242 |
+
# bbox = (t, l, height, width)
|
1243 |
+
h = random.randint(0, max_delta_height//2)
|
1244 |
+
w = random.randint(0, max_delta_width//2)
|
1245 |
+
mask = torch.zeros((1, 1, image_height, image_width))
|
1246 |
+
mask[:, :, t+h:t+height-h, l+w:l+width-w] = 1
|
1247 |
+
rect_mask = mask
|
1248 |
+
|
1249 |
+
# brush mask
|
1250 |
+
min_num_vertex = 4
|
1251 |
+
max_num_vertex = 12
|
1252 |
+
mean_angle = 2 * math.pi / 5
|
1253 |
+
angle_range = 2 * math.pi / 15
|
1254 |
+
min_width = 12
|
1255 |
+
max_width = 40
|
1256 |
+
H, W = image_height, image_width
|
1257 |
+
average_radius = math.sqrt(H*H+W*W) / 8
|
1258 |
+
mask = Image.new('L', (W, H), 0)
|
1259 |
+
|
1260 |
+
for _ in range(np.random.randint(1, 4)):
|
1261 |
+
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
|
1262 |
+
angle_min = mean_angle - np.random.uniform(0, angle_range)
|
1263 |
+
angle_max = mean_angle + np.random.uniform(0, angle_range)
|
1264 |
+
angles = []
|
1265 |
+
vertex = []
|
1266 |
+
for i in range(num_vertex):
|
1267 |
+
if i % 2 == 0:
|
1268 |
+
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
|
1269 |
+
else:
|
1270 |
+
angles.append(np.random.uniform(angle_min, angle_max))
|
1271 |
+
|
1272 |
+
h, w = mask.size
|
1273 |
+
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
|
1274 |
+
for i in range(num_vertex):
|
1275 |
+
r = np.clip(
|
1276 |
+
np.random.normal(loc=average_radius, scale=average_radius//2),
|
1277 |
+
0, 2*average_radius)
|
1278 |
+
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
|
1279 |
+
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
|
1280 |
+
vertex.append((int(new_x), int(new_y)))
|
1281 |
+
|
1282 |
+
draw = ImageDraw.Draw(mask)
|
1283 |
+
width = int(np.random.uniform(min_width, max_width))
|
1284 |
+
draw.line(vertex, fill=255, width=width)
|
1285 |
+
for v in vertex:
|
1286 |
+
draw.ellipse((v[0] - width//2,
|
1287 |
+
v[1] - width//2,
|
1288 |
+
v[0] + width//2,
|
1289 |
+
v[1] + width//2),
|
1290 |
+
fill=255)
|
1291 |
+
|
1292 |
+
if np.random.normal() > 0:
|
1293 |
+
mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
|
1294 |
+
if np.random.normal() > 0:
|
1295 |
+
mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
|
1296 |
+
|
1297 |
+
mask = transforms.ToTensor()(mask)
|
1298 |
+
mask = mask.reshape((1, 1, H, W))
|
1299 |
+
brush_mask = mask
|
1300 |
+
|
1301 |
+
mask = torch.cat([rect_mask, brush_mask], dim=1).max(dim=1, keepdim=True)[0]
|
1302 |
+
list_mask.append(mask)
|
1303 |
+
mask = torch.cat(list_mask, dim=0)
|
1304 |
+
return mask
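# Illustrative sketch (x is a placeholder batch): the returned mask is Nx1xHxW with 1
# marking the free-form region to inpaint, so the masked input is x * (1 - mask).
# xMask = random_mask(num_batch=4, mask_shape=(256, 256))  # 4x1x256x256, values in {0, 1}
# x = torch.rand(4, 3, 256, 256)
# x_holed = x * (1 - xMask)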
|