vict0rsch commited on
Commit
9da944e
·
1 Parent(s): 87ed0da

add intermediate outputs

Browse files
Files changed (3) hide show
  1. app.py +10 -0
  2. climategan/trainer.py +15 -3
  3. climategan_wrapper.py +2 -1
app.py CHANGED
@@ -41,6 +41,8 @@ def predict(cg: ClimateGAN, api_key):
41
  masked_input = output_dict["masked_input"]
42
  wildfire = output_dict["wildfire"]
43
  smog = output_dict["smog"]
 
 
44
 
45
  climategan_flood = output_dict.get(
46
  "climategan_flood",
@@ -62,6 +64,8 @@ def predict(cg: ClimateGAN, api_key):
62
  return (
63
  input_image,
64
  masked_input,
 
 
65
  climategan_flood,
66
  stable_flood,
67
  stable_copy_flood,
@@ -127,6 +131,12 @@ if __name__ == "__main__":
127
  outputs.append(
128
  gr.outputs.Image(type="numpy", label="Masked input image"),
129
  )
 
 
 
 
 
 
130
  with gr.Row():
131
  outputs.append(
132
  gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image"),
 
41
  masked_input = output_dict["masked_input"]
42
  wildfire = output_dict["wildfire"]
43
  smog = output_dict["smog"]
44
+ depth = np.repeat(output_dict["depth"][..., None], 3, axis=-1)
45
+ segmentation = output_dict["segmentation"]
46
 
47
  climategan_flood = output_dict.get(
48
  "climategan_flood",
 
64
  return (
65
  input_image,
66
  masked_input,
67
+ segmentation,
68
+ depth,
69
  climategan_flood,
70
  stable_flood,
71
  stable_copy_flood,
 
131
  outputs.append(
132
  gr.outputs.Image(type="numpy", label="Masked input image"),
133
  )
134
+ outputs.append(
135
+ gr.outputs.Image(type="numpy", label="Segmentation map"),
136
+ )
137
+ outputs.append(
138
+ gr.outputs.Image(type="numpy", label="Depth map"),
139
+ )
140
  with gr.Row():
141
  outputs.append(
142
  gr.outputs.Image(type="numpy", label="ClimateGAN-Flooded image"),
climategan/trainer.py CHANGED
@@ -22,7 +22,8 @@ from torch import autograd, sigmoid, softmax
22
  from torch.cuda.amp import GradScaler, autocast
23
  from tqdm import tqdm
24
 
25
- from climategan.data import get_all_loaders
 
26
  from climategan.discriminator import OmniDiscriminator, create_discriminator
27
  from climategan.eval_metrics import accuracy, mIOU
28
  from climategan.fid import compute_val_fid
@@ -38,6 +39,7 @@ from climategan.tutils import (
38
  get_WGAN_gradient,
39
  lrgb2srgb,
40
  normalize,
 
41
  print_num_parameters,
42
  shuffle_batch_tuple,
43
  srgb2lrgb,
@@ -226,7 +228,7 @@ class Trainer:
226
  cloudy=True,
227
  auto_resize_640=False,
228
  ignore_event=set(),
229
- return_masks=False,
230
  ):
231
  """
232
  Create a dictionnary of events from a numpy or tensor,
@@ -331,10 +333,20 @@ class Trainer:
331
  smog = (smog * 255).astype(np.uint8)
332
  output_data["smog"] = smog
333
 
334
- if return_masks:
335
  output_data["mask"] = (
336
  ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
337
  )
 
 
 
 
 
 
 
 
 
 
338
 
339
  return output_data
340
 
 
22
  from torch.cuda.amp import GradScaler, autocast
23
  from tqdm import tqdm
24
 
25
+ from climategan.data import get_all_loaders, decode_segmap_merged_labels
26
+
27
  from climategan.discriminator import OmniDiscriminator, create_discriminator
28
  from climategan.eval_metrics import accuracy, mIOU
29
  from climategan.fid import compute_val_fid
 
39
  get_WGAN_gradient,
40
  lrgb2srgb,
41
  normalize,
42
+ normalize_tensor,
43
  print_num_parameters,
44
  shuffle_batch_tuple,
45
  srgb2lrgb,
 
228
  cloudy=True,
229
  auto_resize_640=False,
230
  ignore_event=set(),
231
+ return_intermediates=False,
232
  ):
233
  """
234
  Create a dictionnary of events from a numpy or tensor,
 
333
  smog = (smog * 255).astype(np.uint8)
334
  output_data["smog"] = smog
335
 
336
+ if return_intermediates:
337
  output_data["mask"] = (
338
  ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
339
  )
340
+ output_data["depth"] = (
341
+ normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
342
+ )
343
+ output_data["segmentation"] = (
344
+ decode_segmap_merged_labels(segmentation, "r", False)
345
+ .cpu()
346
+ .permute(0, 2, 3, 1)
347
+ .numpy()
348
+ .astype(np.uint8)
349
+ )
350
 
351
  return output_data
352
 
climategan_wrapper.py CHANGED
@@ -15,6 +15,7 @@ from skimage.transform import resize
15
 
16
  from climategan.trainer import Trainer
17
 
 
18
  CUDA = torch.cuda.is_available()
19
 
20
 
@@ -313,7 +314,7 @@ class ClimateGAN:
313
  bin_value=0.5,
314
  half=CUDA,
315
  ignore_event=ignore_event,
316
- return_masks=True,
317
  )
318
 
319
  outputs["input"] = uint8(images, True)
 
15
 
16
  from climategan.trainer import Trainer
17
 
18
+
19
  CUDA = torch.cuda.is_available()
20
 
21
 
 
314
  bin_value=0.5,
315
  half=CUDA,
316
  ignore_event=ignore_event,
317
+ return_intermediates=True,
318
  )
319
 
320
  outputs["input"] = uint8(images, True)