Spaces:
Runtime error
Runtime error
add intermediate outputs
Browse files- app.py +10 -0
- climategan/trainer.py +15 -3
- climategan_wrapper.py +2 -1
app.py
CHANGED
@@ -41,6 +41,8 @@ def predict(cg: ClimateGAN, api_key):
|
|
41 |
masked_input = output_dict["masked_input"]
|
42 |
wildfire = output_dict["wildfire"]
|
43 |
smog = output_dict["smog"]
|
|
|
|
|
44 |
|
45 |
climategan_flood = output_dict.get(
|
46 |
"climategan_flood",
|
@@ -62,6 +64,8 @@ def predict(cg: ClimateGAN, api_key):
|
|
62 |
return (
|
63 |
input_image,
|
64 |
masked_input,
|
|
|
|
|
65 |
climategan_flood,
|
66 |
stable_flood,
|
67 |
stable_copy_flood,
|
@@ -127,6 +131,12 @@ if __name__ == "__main__":
|
|
127 |
outputs.append(
|
128 |
gr.outputs.Image(type="numpy", label="Masked input image"),
|
129 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
with gr.Row():
|
131 |
outputs.append(
|
132 |
gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image"),
|
|
|
41 |
masked_input = output_dict["masked_input"]
|
42 |
wildfire = output_dict["wildfire"]
|
43 |
smog = output_dict["smog"]
|
44 |
+
depth = np.repeat(output_dict["depth"][..., None], 3, axis=-1)
|
45 |
+
segmentation = output_dict["segmentation"]
|
46 |
|
47 |
climategan_flood = output_dict.get(
|
48 |
"climategan_flood",
|
|
|
64 |
return (
|
65 |
input_image,
|
66 |
masked_input,
|
67 |
+
segmentation,
|
68 |
+
depth,
|
69 |
climategan_flood,
|
70 |
stable_flood,
|
71 |
stable_copy_flood,
|
|
|
131 |
outputs.append(
|
132 |
gr.outputs.Image(type="numpy", label="Masked input image"),
|
133 |
)
|
134 |
+
outputs.append(
|
135 |
+
gr.outputs.Image(type="numpy", label="Segmentation map"),
|
136 |
+
)
|
137 |
+
outputs.append(
|
138 |
+
gr.outputs.Image(type="numpy", label="Depth map"),
|
139 |
+
)
|
140 |
with gr.Row():
|
141 |
outputs.append(
|
142 |
gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image"),
|
climategan/trainer.py
CHANGED
@@ -22,7 +22,8 @@ from torch import autograd, sigmoid, softmax
|
|
22 |
from torch.cuda.amp import GradScaler, autocast
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
-
from climategan.data import get_all_loaders
|
|
|
26 |
from climategan.discriminator import OmniDiscriminator, create_discriminator
|
27 |
from climategan.eval_metrics import accuracy, mIOU
|
28 |
from climategan.fid import compute_val_fid
|
@@ -38,6 +39,7 @@ from climategan.tutils import (
|
|
38 |
get_WGAN_gradient,
|
39 |
lrgb2srgb,
|
40 |
normalize,
|
|
|
41 |
print_num_parameters,
|
42 |
shuffle_batch_tuple,
|
43 |
srgb2lrgb,
|
@@ -226,7 +228,7 @@ class Trainer:
|
|
226 |
cloudy=True,
|
227 |
auto_resize_640=False,
|
228 |
ignore_event=set(),
|
229 |
-
|
230 |
):
|
231 |
"""
|
232 |
Create a dictionnary of events from a numpy or tensor,
|
@@ -331,10 +333,20 @@ class Trainer:
|
|
331 |
smog = (smog * 255).astype(np.uint8)
|
332 |
output_data["smog"] = smog
|
333 |
|
334 |
-
if
|
335 |
output_data["mask"] = (
|
336 |
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
337 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
338 |
|
339 |
return output_data
|
340 |
|
|
|
22 |
from torch.cuda.amp import GradScaler, autocast
|
23 |
from tqdm import tqdm
|
24 |
|
25 |
+
from climategan.data import get_all_loaders, decode_segmap_merged_labels
|
26 |
+
|
27 |
from climategan.discriminator import OmniDiscriminator, create_discriminator
|
28 |
from climategan.eval_metrics import accuracy, mIOU
|
29 |
from climategan.fid import compute_val_fid
|
|
|
39 |
get_WGAN_gradient,
|
40 |
lrgb2srgb,
|
41 |
normalize,
|
42 |
+
normalize_tensor,
|
43 |
print_num_parameters,
|
44 |
shuffle_batch_tuple,
|
45 |
srgb2lrgb,
|
|
|
228 |
cloudy=True,
|
229 |
auto_resize_640=False,
|
230 |
ignore_event=set(),
|
231 |
+
return_intermediates=False,
|
232 |
):
|
233 |
"""
|
234 |
Create a dictionnary of events from a numpy or tensor,
|
|
|
333 |
smog = (smog * 255).astype(np.uint8)
|
334 |
output_data["smog"] = smog
|
335 |
|
336 |
+
if return_intermediates:
|
337 |
output_data["mask"] = (
|
338 |
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
339 |
)
|
340 |
+
output_data["depth"] = (
|
341 |
+
normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
|
342 |
+
)
|
343 |
+
output_data["segmentation"] = (
|
344 |
+
decode_segmap_merged_labels(segmentation, "r", False)
|
345 |
+
.cpu()
|
346 |
+
.permute(0, 2, 3, 1)
|
347 |
+
.numpy()
|
348 |
+
.astype(np.uint8)
|
349 |
+
)
|
350 |
|
351 |
return output_data
|
352 |
|
climategan_wrapper.py
CHANGED
@@ -15,6 +15,7 @@ from skimage.transform import resize
|
|
15 |
|
16 |
from climategan.trainer import Trainer
|
17 |
|
|
|
18 |
CUDA = torch.cuda.is_available()
|
19 |
|
20 |
|
@@ -313,7 +314,7 @@ class ClimateGAN:
|
|
313 |
bin_value=0.5,
|
314 |
half=CUDA,
|
315 |
ignore_event=ignore_event,
|
316 |
-
|
317 |
)
|
318 |
|
319 |
outputs["input"] = uint8(images, True)
|
|
|
15 |
|
16 |
from climategan.trainer import Trainer
|
17 |
|
18 |
+
|
19 |
CUDA = torch.cuda.is_available()
|
20 |
|
21 |
|
|
|
314 |
bin_value=0.5,
|
315 |
half=CUDA,
|
316 |
ignore_event=ignore_event,
|
317 |
+
return_intermediates=True,
|
318 |
)
|
319 |
|
320 |
outputs["input"] = uint8(images, True)
|