+```
+
+### Tests
+
+Run tests by executing `python test_trainer.py`. You can add `--no_delete` to keep the comet experiment at exit so you can inspect its uploads.
+
+Write tests as scenarios by adding to the `test_scenarios` list in the file. A scenario is a dict of overrides over the base opts in `shared/trainer/defaults.yaml`. You can create special flags for the scenario by adding keys that start with `__`. For instance, `__doc` is a mandatory key that succinctly describes the scenario.
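+
+For instance, a scenario could look like the following sketch. The `__doc` key follows the convention above; the other override keys are illustrative placeholders, not actual keys from `shared/trainer/defaults.yaml`:
+
+```python
+test_scenarios = [
+    # minimal scenario: run with the base opts, no override
+    {"__doc": "Base options, single training step"},
+    # hypothetical overrides: adapt the keys to those in defaults.yaml
+    {
+        "__doc": "Reduced batch size and a single epoch",
+        "data.loaders.batch_size": 2,
+        "train.epochs": 1,
+    },
+]
+```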
+
+## Resources
+
+[Tricks and Tips for Training a GAN](https://chloes-dl.com/2019/11/19/tricks-and-tips-for-training-a-gan/)
+[GAN Hacks](https://github.com/soumith/ganhacks)
+[Keep Calm and train a GAN. Pitfalls and Tips on training Generative Adversarial Networks](https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9)
+
+## Example
+
+**Inference: computing floods**
+
+```python
+from pathlib import Path
+from skimage.io import imsave
+from tqdm import tqdm
+
+from climategan.trainer import Trainer
+from climategan.utils import find_images
+from climategan.tutils import tensor_ims_to_np_uint8s
+from climategan.transforms import PrepareInference
+
+
+model_path = "some/path/to/output/folder" # not .ckpt
+input_folder = "path/to/a/folder/with/images"
+output_path = "path/where/images/will/be/written"
+
+# resume trainer
+trainer = Trainer.resume_from_path(model_path, new_exp=None, inference=True)
+
+# find paths for all images in the input folder. There is a recursive option.
+im_paths = sorted(find_images(input_folder), key=lambda x: x.name)
+
+# Load images into tensors
+# * smaller side resized to 640 - keeping aspect ratio
+# * then longer side is cropped in the center
+# * result is a 1x3x640x640 float tensor in [-1; 1]
+xs = PrepareInference()(im_paths)
+
+# send to device
+xs = [x.to(trainer.device) for x in xs]
+
+# compute flood
+# * compute mask
+# * binarize mask if bin_value > 0
+# * paint x using this mask
+ys = [trainer.compute_flood(x, bin_value=0.5) for x in tqdm(xs)]
+
+# convert 1x3x640x640 float tensors in [-1; 1] into 640x640x3 numpy arrays in [0, 255]
+np_ys = [tensor_ims_to_np_uint8s(y) for y in tqdm(ys)]
+
+# write images
+for i, n in tqdm(zip(im_paths, np_ys), total=len(im_paths)):
+ imsave(Path(output_path) / i.name, n)
+```
+
+## Release process
+
+In the `release/` folder (a scripted sketch of these steps follows the list):
+* create a `model/` folder
+* create folders `model/masker/` and `model/painter/`
+* clone the climategan code into `release/`: `git clone git@github.com:cc-ai/climategan.git`
+* move the code up into `release/` itself: `cp climategan/* . && rm -rf climategan`
+* update `model/masker/opts/events` with `events:` from `shared/trainer/opts.yaml`
+* update `model/masker/opts/val.val_painter` to `"model/painter/checkpoints/latest_ckpt.pth"`
+* update `model/masker/opts/load_paths.m` to `"model/masker/checkpoints/latest_ckpt.pth"`
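+
+The sketch below scripts these steps in Python. The `events`, `val.val_painter` and `load_paths.m` keys and the `shared/trainer/opts.yaml` path come from the list above; the masker opts file name (`model/masker/opts.yaml`) is an assumption:
+
+```python
+import subprocess
+from pathlib import Path
+
+import yaml
+
+release = Path("release")
+
+# create model/, model/masker/ and model/painter/
+(release / "model" / "masker").mkdir(parents=True, exist_ok=True)
+(release / "model" / "painter").mkdir(parents=True, exist_ok=True)
+
+# clone the climategan code into release/
+subprocess.run(
+    ["git", "clone", "git@github.com:cc-ai/climategan.git"], cwd=release, check=True
+)
+# then flatten it: cp climategan/* . && rm -rf climategan (left as a shell step)
+
+# update the masker opts with the trainer's events and the release checkpoint paths
+trainer_opts = yaml.safe_load(Path("shared/trainer/opts.yaml").read_text())
+masker_opts_path = release / "model" / "masker" / "opts.yaml"  # assumed location
+masker_opts = yaml.safe_load(masker_opts_path.read_text())
+
+masker_opts["events"] = trainer_opts["events"]
+masker_opts["val"]["val_painter"] = "model/painter/checkpoints/latest_ckpt.pth"
+masker_opts["load_paths"]["m"] = "model/masker/checkpoints/latest_ckpt.pth"
+
+masker_opts_path.write_text(yaml.safe_dump(masker_opts))
+```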
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6e69a289efd267620bf894ead2c297ff05f943
--- /dev/null
+++ b/app.py
@@ -0,0 +1,70 @@
+# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/app.py # noqa: E501
+# thank you @NimaBoscarino
+
+import os
+import gradio as gr
+import googlemaps
+from skimage import io
+from urllib import parse
+from inferences import ClimateGAN
+
+
+def predict(api_key):
+ def _predict(*args):
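+ # Gradio passes one positional value per input component:
+ # (image,) when only the image input is enabled, or
+ # (image, place) when the address textbox is enabled as well.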
+ print("args: ", args)
+ image = place = None
+ if len(args) == 1:
+ image = args[0]
+ else:
+ assert len(args) == 2, "Unknown number of inputs {}".format(len(args))
+ image, place = args
+
+ if api_key and place:
+ geocode_result = gmaps.geocode(place)
+
+ address = geocode_result[0]["formatted_address"]
+ static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={parse.quote(address)}&source=outdoor&key={api_key}"
+ img_np = io.imread(static_map_url)
+ else:
+ img_np = image
+ flood, wildfire, smog = model.inference(img_np)
+ return img_np, flood, wildfire, smog
+
+ return _predict
+
+
+if __name__ == "__main__":
+
+ api_key = os.environ.get("GMAPS_API_KEY")
+ gmaps = None
+ if api_key is not None:
+ gmaps = googlemaps.Client(key=api_key)
+
+ model = ClimateGAN(model_path="config/model/masker")
+
+ inputs = [gr.inputs.Image(label="Input Image")]
+ if api_key:
+ inputs += [gr.inputs.Textbox(label="Address or place name")]
+
+ gr.Interface(
+ predict(api_key),
+ inputs=inputs,
+ outputs=[
+ gr.outputs.Image(type="numpy", label="Original image"),
+ gr.outputs.Image(type="numpy", label="Flooding"),
+ gr.outputs.Image(type="numpy", label="Wildfire"),
+ gr.outputs.Image(type="numpy", label="Smog"),
+ ],
+ title="ClimateGAN: Visualize Climate Change",
+ description='Climate change does not impact everyone equally. This Space shows the effects of the climate emergency, "one address at a time". Visit the original experience at ThisClimateDoesNotExist.com. Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.',  # noqa: E501
+ article="This project is an unofficial clone of ThisClimateDoesNotExist | ClimateGAN GitHub Repo",  # noqa: E501
+ # examples=[
+ # "Vancouver Art Gallery",
+ # "Chicago Bean",
+ # "Duomo Siracusa",
+ # ],
+ css=".footer{display:none !important}",
+ ).launch()
diff --git a/apply_events.py b/apply_events.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7df616f9af7a3c8877af50b017f1bbc0df78889
--- /dev/null
+++ b/apply_events.py
@@ -0,0 +1,642 @@
+import argparse
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-b",
+ "--batch_size",
+ type=int,
+ default=4,
+ help="Batch size to process input images to events. Defaults to 4",
+ )
+ parser.add_argument(
+ "-i",
+ "--images_paths",
+ type=str,
+ required=True,
+ help="Path to a directory with image files",
+ )
+ parser.add_argument(
+ "-o",
+ "--output_path",
+ type=str,
+ default=None,
+ help="Path to a directory were events should be written. "
+ + "Will NOT write anything to disk if this flag is not used.",
+ )
+ parser.add_argument(
+ "-s",
+ "--save_input",
+ action="store_true",
+ default=False,
+ help="Binary flag to include the input image to the model (after crop and"
+ + " resize) in the images written or uploaded (depending on saving options.)",
+ )
+ parser.add_argument(
+ "-r",
+ "--resume_path",
+ type=str,
+ default=None,
+ help="Path to a directory containing the trainer to resume."
+ + " In particular it must contain `opts.yam` and `checkpoints/`."
+ + " Typically this points to a Masker, which holds the path to a"
+ + " Painter in its opts",
+ )
+ parser.add_argument(
+ "--no_time",
+ action="store_true",
+ default=False,
+ help="Binary flag to prevent the timing of operations.",
+ )
+ parser.add_argument(
+ "-f",
+ "--flood_mask_binarization",
+ type=float,
+ default=0.5,
+ help="Value to use to binarize masks (mask > value). "
+ + "Set to -1 to use soft masks (not binarized). Defaults to 0.5.",
+ )
+ parser.add_argument(
+ "-t",
+ "--target_size",
+ type=int,
+ default=640,
+ help="Output image size (when not using `keep_ratio_128`): images are resized"
+ + " such that their smallest side is `target_size` then cropped in the middle"
+ + " of the largest side such that the resulting input image (and output images)"
+ + " has height and width `target_size x target_size`. **Must** be a multiple of"
+ + " 2^7=128 (up/downscaling inside the models). Defaults to 640.",
+ )
+ parser.add_argument(
+ "--half",
+ action="store_true",
+ default=False,
+ help="Binary flag to use half precision (float16). Defaults to False.",
+ )
+ parser.add_argument(
+ "-n",
+ "--n_images",
+ default=-1,
+ type=int,
+ help="Limit the number of images processed (if you have 100 images in "
+ + "a directory but n is 10 then only the first 10 images will be loaded"
+ + " for processing)",
+ )
+ parser.add_argument(
+ "--no_conf",
+ action="store_true",
+ default=False,
+ help="disable writing the apply_events hash and command in the output folder",
+ )
+ parser.add_argument(
+ "--overwrite",
+ action="store_true",
+ default=False,
+ help="Do not check for existing outdir, i.e. force overwrite"
+ + " potentially existing files in the output path",
+ )
+ parser.add_argument(
+ "--no_cloudy",
+ action="store_true",
+ default=False,
+ help="Prevent the use of the cloudy intermediate"
+ + " image to create the flood image. Rendering will"
+ + " be more colorful but may seem less realistic",
+ )
+ parser.add_argument(
+ "--keep_ratio_128",
+ action="store_true",
+ default=False,
+ help="When loading the input images, resize and crop them in order for their "
+ + "dimensions to match the closest multiples"
+ + " of 128. Will force a batch size of 1 since images"
+ + " now have different dimensions. "
+ + "Use --max_im_width to cap the resulting dimensions.",
+ )
+ parser.add_argument(
+ "--fuse",
+ action="store_true",
+ default=False,
+ help="Use batch norm fusion to speed up inference",
+ )
+ parser.add_argument(
+ "--save_masks",
+ action="store_true",
+ default=False,
+ help="Save output masks along events",
+ )
+ parser.add_argument(
+ "-m",
+ "--max_im_width",
+ type=int,
+ default=-1,
+ help="When using --keep_ratio_128, some images may still be too large. Use "
+ + "--max_im_width to cap the resized image's width. Defaults to -1 (no cap).",
+ )
+ parser.add_argument(
+ "--upload",
+ action="store_true",
+ help="Upload to comet.ml in a project called `climategan-apply`",
+ )
+ parser.add_argument(
+ "--zip_outdir",
+ "-z",
+ action="store_true",
+ help="Zip the output directory as '{outdir.parent}/{outdir.name}.zip'",
+ )
+ return parser.parse_args()
+
+
+args = parse_args()
+
+
+print("\n• Imports\n")
+import time
+
+import_time = time.time()
+import sys
+import shutil
+from collections import OrderedDict
+from pathlib import Path
+
+import comet_ml # noqa: F401
+import torch
+import numpy as np
+import skimage.io as io
+from skimage.color import rgba2rgb
+from skimage.transform import resize
+from tqdm import tqdm
+
+from climategan.trainer import Trainer
+from climategan.bn_fusion import bn_fuse
+from climategan.tutils import print_num_parameters
+from climategan.utils import Timer, find_images, get_git_revision_hash, to_128, resolve
+
+import_time = time.time() - import_time
+
+
+def to_m1_p1(img, i):
+ """
+ rescales a [0, 1] image to [-1, +1]
+
+ Args:
+ img (np.array): float32 numpy array of an image in [0, 1]
+ i (int): Index of the image being rescaled
+
+ Raises:
+ ValueError: If the image is not in [0, 1]
+
+ Returns:
+ np.array(np.float32): array in [-1, +1]
+ """
+ if img.min() >= 0 and img.max() <= 1:
+ return (img.astype(np.float32) - 0.5) * 2
+ raise ValueError(f"Data range mismatch for image {i} : ({img.min()}, {img.max()})")
+
+
+def uint8(array):
+ """
+ convert an array to np.uint8 (does not rescale or anything else than changing dtype)
+
+ Args:
+ array (np.array): array to modify
+
+ Returns:
+ np.array(np.uint8): converted array
+ """
+ return array.astype(np.uint8)
+
+
+def resize_and_crop(img, to=640):
+ """
+ Resizes an image so that it keeps the aspect ratio and its smallest dimension
+ is `to`, then crops this resized image in its center so that the output is `to x to`
+ without aspect ratio distortion
+
+ Args:
+ img (np.array): np.uint8 255 image
+
+ Returns:
+ np.array: [0, 1] np.float32 image
+ """
+ # resize keeping aspect ratio: smallest dim is `to`
+ h, w = img.shape[:2]
+ if h < w:
+ size = (to, int(to * w / h))
+ else:
+ size = (int(to * h / w), to)
+
+ r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
+ r_img = uint8(r_img)
+
+ # crop in the center
+ H, W = r_img.shape[:2]
+
+ top = (H - to) // 2
+ left = (W - to) // 2
+
+ rc_img = r_img[top : top + to, left : left + to, :]
+
+ return rc_img / 255.0
+
+
+def print_time(text, time_series, purge=-1):
+ """
+ Print a timeseries's mean and std with a label
+
+ Args:
+ text (str): label of the time series
+ time_series (list): list of timings
+ purge (int, optional): ignore first n values of time series. Defaults to -1.
+ """
+ if not time_series:
+ return
+
+ if purge > 0 and len(time_series) > purge:
+ time_series = time_series[purge:]
+
+ m = np.mean(time_series)
+ s = np.std(time_series)
+
+ print(
+ f"{text.capitalize() + ' ':.<26} {m:.5f}"
+ + (f" +/- {s:.5f}" if len(time_series) > 1 else "")
+ )
+
+
+def print_store(store, purge=-1):
+ """
+ Pretty-print time series store
+
+ Args:
+ store (dict): maps string keys to lists of times
+ purge (int, optional): ignore first n values of time series. Defaults to -1.
+ """
+ singles = OrderedDict({k: v for k, v in store.items() if len(v) == 1})
+ multiples = OrderedDict({k: v for k, v in store.items() if len(v) > 1})
+ empties = {k: v for k, v in store.items() if len(v) == 0}
+
+ if empties:
+ print("Ignoring empty stores ", ", ".join(empties.keys()))
+ print()
+
+ for k in singles:
+ print_time(k, singles[k], purge)
+
+ print()
+ print("Unit: s/batch")
+ for k in multiples:
+ print_time(k, multiples[k], purge)
+ print()
+
+
+def write_apply_config(out):
+ """
+ Saves the args to `apply_events.py` in a text file for future reference
+ """
+ cwd = Path.cwd().expanduser().resolve()
+ command = f"cd {str(cwd)}\n"
+ command += " ".join(sys.argv)
+ git_hash = get_git_revision_hash()
+ with (out / "command.txt").open("w") as f:
+ f.write(command)
+ with (out / "hash.txt").open("w") as f:
+ f.write(git_hash)
+
+
+def get_outdir_name(half, keep_ratio, max_im_width, target_size, bin_value, cloudy):
+ """
+ Create the output directory's name based on user-provided arguments
+ """
+ name_items = []
+ if half:
+ name_items.append("half")
+ if keep_ratio:
+ name_items.append("AR")
+ if max_im_width and keep_ratio:
+ name_items.append(f"{max_im_width}")
+ if target_size and not keep_ratio:
+ name_items.append("S")
+ name_items.append(f"{target_size}")
+ if bin_value != 0.5:
+ name_items.append(f"bin{bin_value}")
+ if not cloudy:
+ name_items.append("no_cloudy")
+
+ return "-".join(name_items)
+
+
+def make_outdir(
+ outdir, overwrite, half, keep_ratio, max_im_width, target_size, bin_value, cloudy
+):
+ """
+ Creates the output directory if it does not exist. If it does exist,
+ prompts the user for confirmation (except if `overwrite` is True).
+ If the output directory's name is "_auto_" then it is created as:
+ outdir.parent / get_outdir_name(...)
+ """
+ if outdir.name == "_auto_":
+ outdir = outdir.parent / get_outdir_name(
+ half, keep_ratio, max_im_width, target_size, bin_value, cloudy
+ )
+ if outdir.exists() and not overwrite:
+ print(
+ f"\nWARNING: outdir ({str(outdir)}) already exists."
+ + " Files with existing names will be overwritten"
+ )
+ if "n" in input(">>> Continue anyway? [y / n] (default: y) : "):
+ print("Interrupting execution from user input.")
+ sys.exit()
+ print()
+ outdir.mkdir(exist_ok=True, parents=True)
+ return outdir
+
+
+def get_time_stores(import_time):
+ return OrderedDict(
+ {
+ "imports": [import_time],
+ "setup": [],
+ "data pre-processing": [],
+ "encode": [],
+ "mask": [],
+ "flood": [],
+ "depth": [],
+ "segmentation": [],
+ "smog": [],
+ "wildfire": [],
+ "all events": [],
+ "numpy": [],
+ "inference on all images": [],
+ "write": [],
+ }
+ )
+
+
+if __name__ == "__main__":
+
+ # -----------------------------------------
+ # ----- Initialize script variables -----
+ # -----------------------------------------
+ print(
+ "• Using args\n\n"
+ + "\n".join(["{:25}: {}".format(k, v) for k, v in vars(args).items()]),
+ )
+
+ batch_size = args.batch_size
+ bin_value = args.flood_mask_binarization
+ cloudy = not args.no_cloudy
+ fuse = args.fuse
+ half = args.half
+ save_masks = args.save_masks
+ images_paths = resolve(args.images_paths)
+ keep_ratio = args.keep_ratio_128
+ max_im_width = args.max_im_width
+ n_images = args.n_images
+ outdir = resolve(args.output_path) if args.output_path is not None else None
+ resume_path = args.resume_path
+ target_size = args.target_size
+ time_inference = not args.no_time
+ upload = args.upload
+ zip_outdir = args.zip_outdir
+
+ # -------------------------------------
+ # ----- Validate size arguments -----
+ # -------------------------------------
+ if keep_ratio:
+ if target_size != 640:
+ print(
+ "\nWARNING: using --keep_ratio_128 overwrites target_size"
+ + " which is ignored."
+ )
+ if batch_size != 1:
+ print("\nWARNING: batch_size overwritten to 1 when using keep_ratio_128")
+ batch_size = 1
+ if max_im_width > 0 and max_im_width % 128 != 0:
+ new_im_width = int(max_im_width / 128) * 128
+ print("\nWARNING: max_im_width should be <0 or a multiple of 128.")
+ print(
+ " Was {} but is now overwritten to {}".format(
+ max_im_width, new_im_width
+ )
+ )
+ max_im_width = new_im_width
+ else:
+ if target_size % 128 != 0:
+ print(f"\nWarning: target size {target_size} is not a multiple of 128.")
+ target_size = target_size - (target_size % 128)
+ print(f"Setting target_size to {target_size}.")
+
+ # -------------------------------------
+ # ----- Create output directory -----
+ # -------------------------------------
+ if outdir is not None:
+ outdir = make_outdir(
+ outdir,
+ args.overwrite,
+ half,
+ keep_ratio,
+ max_im_width,
+ target_size,
+ bin_value,
+ cloudy,
+ )
+
+ # -------------------------------
+ # ----- Create time store -----
+ # -------------------------------
+ stores = get_time_stores(import_time)
+
+ # -----------------------------------
+ # ----- Load Trainer instance -----
+ # -----------------------------------
+ with Timer(store=stores.get("setup", []), ignore=time_inference):
+ print("\n• Initializing trainer\n")
+ torch.set_grad_enabled(False)
+ trainer = Trainer.resume_from_path(
+ resume_path,
+ setup=True,
+ inference=True,
+ new_exp=None,
+ )
+ print()
+ print_num_parameters(trainer, True)
+ if fuse:
+ trainer.G = bn_fuse(trainer.G)
+ if half:
+ trainer.G.half()
+
+ # --------------------------------------------
+ # ----- Read data from input directory -----
+ # --------------------------------------------
+ print("\n• Reading & Pre-processing Data\n")
+
+ # find all images
+ data_paths = find_images(images_paths)
+ base_data_paths = data_paths
+ # filter images
+ if 0 < n_images < len(data_paths):
+ data_paths = data_paths[:n_images]
+ # repeat data
+ elif n_images > len(data_paths):
+ repeats = n_images // len(data_paths) + 1
+ data_paths = base_data_paths * repeats
+ data_paths = data_paths[:n_images]
+
+ with Timer(store=stores.get("data pre-processing", []), ignore=time_inference):
+ # read images to numpy arrays
+ data = [io.imread(str(d)) for d in data_paths]
+ # rgba to rgb
+ data = [im if im.shape[-1] == 3 else uint8(rgba2rgb(im) * 255) for im in data]
+ # resize images to target_size or
+ if keep_ratio:
+ # to closest multiples of 128 <= max_im_width, keeping aspect ratio
+ new_sizes = [to_128(d, max_im_width) for d in data]
+ data = [resize(d, ns, anti_aliasing=True) for d, ns in zip(data, new_sizes)]
+ else:
+ # to args.target_size
+ data = [resize_and_crop(d, target_size) for d in data]
+ new_sizes = [(target_size, target_size) for _ in data]
+ # resize() produces [0, 1] images, rescale to [-1, 1]
+ data = [to_m1_p1(d, i) for i, d in enumerate(data)]
+
+ n_batchs = len(data) // batch_size
+ if len(data) % batch_size != 0:
+ n_batchs += 1
+
+ print("Found", len(base_data_paths), "images. Inferring on", len(data), "images.")
+
+ # --------------------------------------------
+ # ----- Batch-process images to events -----
+ # --------------------------------------------
+ print(f"\n• Using device {str(trainer.device)}\n")
+
+ all_events = []
+
+ with Timer(store=stores.get("inference on all images", []), ignore=time_inference):
+ for b in tqdm(range(n_batchs), desc="Infering events", unit="batch"):
+
+ images = data[b * batch_size : (b + 1) * batch_size]
+ if not images:
+ continue
+
+ # concatenate images into a batch of shape batch_size x height x width x 3
+ images = np.stack(images)
+ # Retrieve numpy events as a dict {event: array[BxHxWxC]}
+ events = trainer.infer_all(
+ images,
+ numpy=True,
+ stores=stores,
+ bin_value=bin_value,
+ half=half,
+ cloudy=cloudy,
+ return_masks=save_masks,
+ )
+
+ # save resized and cropped image
+ if args.save_input:
+ events["input"] = uint8((images + 1) / 2 * 255)
+
+ # store events to write after inference loop
+ all_events.append(events)
+
+ # --------------------------------------------
+ # ----- Save (write/upload) inferences -----
+ # --------------------------------------------
+ if outdir is not None or upload:
+
+ if upload:
+ print("\n• Creating comet Experiment")
+ exp = comet_ml.Experiment(project_name="climategan-apply")
+ exp.log_parameters(vars(args))
+
+ # --------------------------------------------------------------
+ # ----- Change inferred data structure to a list of dicts -----
+ # --------------------------------------------------------------
+ to_write = []
+ events_names = list(all_events[0].keys())
+ for events_data in all_events:
+ n_ims = len(events_data[events_names[0]])
+ for i in range(n_ims):
+ item = {event: events_data[event][i] for event in events_names}
+ to_write.append(item)
+
+ progress_bar_desc = ""
+ if outdir is not None:
+ print("\n• Output directory:\n")
+ print(str(outdir), "\n")
+ if upload:
+ progress_bar_desc = "Writing & Uploading events"
+ else:
+ progress_bar_desc = "Writing events"
+ else:
+ if upload:
+ progress_bar_desc = "Uploading events"
+
+ # ------------------------------------
+ # ----- Save individual images -----
+ # ------------------------------------
+ with Timer(store=stores.get("write", []), ignore=time_inference):
+
+ # for each image
+ for t, event_dict in tqdm(
+ enumerate(to_write),
+ desc=progress_bar_desc,
+ unit="input image",
+ total=len(to_write),
+ ):
+
+ idx = t % len(base_data_paths)
+ stem = Path(data_paths[idx]).stem
+ width = new_sizes[idx][1]
+
+ if keep_ratio:
+ ar = "_AR"
+ else:
+ ar = ""
+
+ # for each event type
+ event_bar = tqdm(
+ enumerate(event_dict.items()),
+ leave=False,
+ total=len(events_names),
+ unit="event",
+ )
+ for e, (event, im_data) in event_bar:
+ event_bar.set_description(
+ f" {event.capitalize():<{len(progress_bar_desc) - 2}}"
+ )
+
+ if args.no_cloudy:
+ suffix = ar + "_no_cloudy"
+ else:
+ suffix = ar
+
+ im_path = Path(f"{stem}_{event}_{width}{suffix}.png")
+
+ if outdir is not None:
+ im_path = outdir / im_path
+ io.imsave(im_path, im_data)
+
+ if upload:
+ exp.log_image(im_data, name=im_path.name)
+ if zip_outdir:
+ print("\n• Zipping output directory... ", end="", flush=True)
+ archive_path = Path(shutil.make_archive(outdir.name, "zip", root_dir=outdir))
+ archive_path = archive_path.rename(outdir.parent / archive_path.name)
+ print("Done:\n")
+ print(str(archive_path))
+
+ # ---------------------------
+ # ----- Print timings -----
+ # ---------------------------
+ if time_inference:
+ print("\n• Timings\n")
+ print_store(stores)
+
+ # ---------------------------------------------
+ # ----- Save apply_events.py run config -----
+ # ---------------------------------------------
+ if not args.no_conf and outdir is not None:
+ write_apply_config(outdir)
diff --git a/climategan/__init__.py b/climategan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..edfc9bec8573c946217947a2329afd7d2d05ec08
--- /dev/null
+++ b/climategan/__init__.py
@@ -0,0 +1,9 @@
+from importlib import import_module
+from pathlib import Path
+
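+# Dynamically import every top-level module of the package (every *.py file
+# whose name does not contain "__") when `climategan` itself is imported.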
+__all__ = [
+ import_module(f".{f.stem}", __package__)
+ for f in Path(__file__).parent.glob("*.py")
+ if "__" not in f.stem
+]
+del import_module, Path
diff --git a/climategan/blocks.py b/climategan/blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cab41f528d47c859fd17167092aabd6bcc359cd
--- /dev/null
+++ b/climategan/blocks.py
@@ -0,0 +1,398 @@
+"""File for all blocks which are parts of decoders
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import climategan.strings as strings
+from climategan.norms import SPADE, AdaptiveInstanceNorm2d, LayerNorm, SpectralNorm
+
+
+class InterpolateNearest2d(nn.Module):
+ """
+ Custom implementation of nn.Upsample because pytorch/xla
+ does not yet support scale_factor and needs to be provided with
+ the output_size
+ """
+
+ def __init__(self, scale_factor=2):
+ """
+ Create an InterpolateNearest2d module
+
+ Args:
+ scale_factor (int, optional): Output size multiplier. Defaults to 2.
+ """
+ super().__init__()
+ self.scale_factor = scale_factor
+
+ def forward(self, x):
+ """
+ Interpolate x in "nearest" mode on its last 2 dimensions
+
+ Args:
+ x (torch.Tensor): input to interpolate
+
+ Returns:
+ torch.Tensor: upsampled tensor with shape
+ (...x.shape, x.shape[-2] * scale_factor, x.shape[-1] * scale_factor)
+ """
+ return F.interpolate(
+ x,
+ size=(x.shape[-2] * self.scale_factor, x.shape[-1] * self.scale_factor),
+ mode="nearest",
+ )
+
+
+# -----------------------------------------
+# ----- Generic Convolutional Block -----
+# -----------------------------------------
+class Conv2dBlock(nn.Module):
+ def __init__(
+ self,
+ input_dim,
+ output_dim,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ norm="none",
+ activation="relu",
+ pad_type="zero",
+ bias=True,
+ ):
+ super().__init__()
+ self.use_bias = bias
+ # initialize padding
+ if pad_type == "reflect":
+ self.pad = nn.ReflectionPad2d(padding)
+ elif pad_type == "replicate":
+ self.pad = nn.ReplicationPad2d(padding)
+ elif pad_type == "zero":
+ self.pad = nn.ZeroPad2d(padding)
+ else:
+ assert 0, "Unsupported padding type: {}".format(pad_type)
+
+ # initialize normalization
+ use_spectral_norm = False
+ if norm.startswith("spectral_"):
+ norm = norm.replace("spectral_", "")
+ use_spectral_norm = True
+
+ norm_dim = output_dim
+ if norm == "batch":
+ self.norm = nn.BatchNorm2d(norm_dim)
+ elif norm == "instance":
+ # self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True)
+ self.norm = nn.InstanceNorm2d(norm_dim)
+ elif norm == "layer":
+ self.norm = LayerNorm(norm_dim)
+ elif norm == "adain":
+ self.norm = AdaptiveInstanceNorm2d(norm_dim)
+ elif norm == "spectral" or norm.startswith("spectral_"):
+ self.norm = None # dealt with later in the code
+ elif norm == "none":
+ self.norm = None
+ else:
+ raise ValueError("Unsupported normalization: {}".format(norm))
+
+ # initialize activation
+ if activation == "relu":
+ self.activation = nn.ReLU(inplace=False)
+ elif activation == "lrelu":
+ self.activation = nn.LeakyReLU(0.2, inplace=False)
+ elif activation == "prelu":
+ self.activation = nn.PReLU()
+ elif activation == "selu":
+ self.activation = nn.SELU(inplace=False)
+ elif activation == "tanh":
+ self.activation = nn.Tanh()
+ elif activation == "sigmoid":
+ self.activation = nn.Sigmoid()
+ elif activation == "none":
+ self.activation = None
+ else:
+ raise ValueError("Unsupported activation: {}".format(activation))
+
+ # initialize convolution
+ if norm == "spectral" or use_spectral_norm:
+ self.conv = SpectralNorm(
+ nn.Conv2d(
+ input_dim,
+ output_dim,
+ kernel_size,
+ stride,
+ dilation=dilation,
+ bias=self.use_bias,
+ )
+ )
+ else:
+ self.conv = nn.Conv2d(
+ input_dim,
+ output_dim,
+ kernel_size,
+ stride,
+ dilation=dilation,
+ bias=self.use_bias if norm != "batch" else False,
+ )
+
+ def forward(self, x):
+ x = self.conv(self.pad(x))
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.activation is not None:
+ x = self.activation(x)
+ return x
+
+ def __str__(self):
+ return strings.conv2dblock(self)
+
+
+# -----------------------------
+# ----- Residual Blocks -----
+# -----------------------------
+class ResBlocks(nn.Module):
+ """
+ From https://github.com/NVlabs/MUNIT/blob/master/networks.py
+ """
+
+ def __init__(self, num_blocks, dim, norm="in", activation="relu", pad_type="zero"):
+ super().__init__()
+ self.model = nn.Sequential(
+ *[
+ ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(self, x):
+ return self.model(x)
+
+ def __str__(self):
+ return strings.resblocks(self)
+
+
+class ResBlock(nn.Module):
+ def __init__(self, dim, norm="in", activation="relu", pad_type="zero"):
+ super().__init__()
+ self.dim = dim
+ self.norm = norm
+ self.activation = activation
+ model = []
+ model += [
+ Conv2dBlock(
+ dim, dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type
+ )
+ ]
+ model += [
+ Conv2dBlock(
+ dim, dim, 3, 1, 1, norm=norm, activation="none", pad_type=pad_type
+ )
+ ]
+ self.model = nn.Sequential(*model)
+
+ def forward(self, x):
+ residual = x
+ out = self.model(x)
+ out += residual
+ return out
+
+ def __str__(self):
+ return strings.resblock(self)
+
+
+# --------------------------
+# ----- Base Decoder -----
+# --------------------------
+class BaseDecoder(nn.Module):
+ def __init__(
+ self,
+ n_upsample=4,
+ n_res=4,
+ input_dim=2048,
+ proj_dim=64,
+ output_dim=3,
+ norm="batch",
+ activ="relu",
+ pad_type="zero",
+ output_activ="tanh",
+ low_level_feats_dim=-1,
+ use_dada=False,
+ ):
+ super().__init__()
+
+ self.low_level_feats_dim = low_level_feats_dim
+ self.use_dada = use_dada
+
+ self.model = []
+ if proj_dim != -1:
+ self.proj_conv = Conv2dBlock(
+ input_dim, proj_dim, 1, 1, 0, norm=norm, activation=activ
+ )
+ else:
+ self.proj_conv = None
+ proj_dim = input_dim
+
+ if low_level_feats_dim > 0:
+ self.low_level_conv = Conv2dBlock(
+ input_dim=low_level_feats_dim,
+ output_dim=proj_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ pad_type=pad_type,
+ norm=norm,
+ activation=activ,
+ )
+ self.merge_feats_conv = Conv2dBlock(
+ input_dim=2 * proj_dim,
+ output_dim=proj_dim,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ pad_type=pad_type,
+ norm=norm,
+ activation=activ,
+ )
+ else:
+ self.low_level_conv = None
+
+ self.model += [ResBlocks(n_res, proj_dim, norm, activ, pad_type=pad_type)]
+ dim = proj_dim
+ # upsampling blocks
+ for i in range(n_upsample):
+ self.model += [
+ InterpolateNearest2d(scale_factor=2),
+ Conv2dBlock(
+ input_dim=dim,
+ output_dim=dim // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ pad_type=pad_type,
+ norm=norm,
+ activation=activ,
+ ),
+ ]
+ dim //= 2
+ # final conv block maps to output_dim channels with the output activation
+ self.model += [
+ Conv2dBlock(
+ input_dim=dim,
+ output_dim=output_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ pad_type=pad_type,
+ norm="none",
+ activation=output_activ,
+ )
+ ]
+ self.model = nn.Sequential(*self.model)
+
+ def forward(self, z, cond=None, z_depth=None):
+ low_level_feat = None
+ if isinstance(z, (list, tuple)):
+ if self.low_level_conv is None:
+ z = z[0]
+ else:
+ z, low_level_feat = z
+ low_level_feat = self.low_level_conv(low_level_feat)
+ low_level_feat = F.interpolate(
+ low_level_feat, size=z.shape[-2:], mode="bilinear"
+ )
+
+ if z_depth is not None and self.use_dada:
+ z = z * z_depth
+
+ if self.proj_conv is not None:
+ z = self.proj_conv(z)
+
+ if low_level_feat is not None:
+ z = self.merge_feats_conv(torch.cat([low_level_feat, z], dim=1))
+
+ return self.model(z)
+
+ def __str__(self):
+ return strings.basedecoder(self)
+
+
+# --------------------------
+# ----- SPADE Blocks -----
+# --------------------------
+# https://github.com/NVlabs/SPADE/blob/0ff661e70131c9b85091d11a66e019c0f2062d4c
+# /models/networks/generator.py
+# 0ff661e on 13 Apr 2019
+class SPADEResnetBlock(nn.Module):
+ def __init__(
+ self,
+ fin,
+ fout,
+ cond_nc,
+ spade_use_spectral_norm,
+ spade_param_free_norm,
+ spade_kernel_size,
+ last_activation=None,
+ ):
+ super().__init__()
+ # Attributes
+
+ self.fin = fin
+ self.fout = fout
+ self.use_spectral_norm = spade_use_spectral_norm
+ self.param_free_norm = spade_param_free_norm
+ self.kernel_size = spade_kernel_size
+
+ self.learned_shortcut = fin != fout
+ self.last_activation = last_activation
+ fmiddle = min(fin, fout)
+
+ # create conv layers
+ self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
+ self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
+ if self.learned_shortcut:
+ self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+
+ # apply spectral norm if specified
+ if spade_use_spectral_norm:
+ self.conv_0 = SpectralNorm(self.conv_0)
+ self.conv_1 = SpectralNorm(self.conv_1)
+ if self.learned_shortcut:
+ self.conv_s = SpectralNorm(self.conv_s)
+
+ self.norm_0 = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)
+ self.norm_1 = SPADE(spade_param_free_norm, spade_kernel_size, fmiddle, cond_nc)
+ if self.learned_shortcut:
+ self.norm_s = SPADE(spade_param_free_norm, spade_kernel_size, fin, cond_nc)
+
+ # note the resnet block with SPADE also takes in |seg|,
+ # the semantic segmentation map as input
+ def forward(self, x, seg):
+ x_s = self.shortcut(x, seg)
+
+ dx = self.conv_0(self.activation(self.norm_0(x, seg)))
+ dx = self.conv_1(self.activation(self.norm_1(dx, seg)))
+
+ out = x_s + dx
+ if self.last_activation == "lrelu":
+ return self.activation(out)
+ elif self.last_activation is None:
+ return out
+ else:
+ raise NotImplementedError(
+ "The type of activation is not supported: {}".format(
+ self.last_activation
+ )
+ )
+
+ def shortcut(self, x, seg):
+ if self.learned_shortcut:
+ x_s = self.conv_s(self.norm_s(x, seg))
+ else:
+ x_s = x
+ return x_s
+
+ def activation(self, x):
+ return F.leaky_relu(x, 2e-1)
+
+ def __str__(self):
+ return strings.spaderesblock(self)
diff --git a/climategan/bn_fusion.py b/climategan/bn_fusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c1171a927a21d76a89d691467005e7f419b296
--- /dev/null
+++ b/climategan/bn_fusion.py
@@ -0,0 +1,137 @@
+import torch
+from copy import deepcopy
+
+
+class FlattableModel(object):
+ def __init__(self, model):
+ self.model = deepcopy(model)
+ self._original_model = model
+ self._flat_model = None
+ self._attr_names = self.get_attributes_name()
+
+ def flatten_model(self):
+ if self._flat_model is None:
+ self._flat_model = self._flatten_model(self.model)
+ return self._flat_model
+
+ @staticmethod
+ def _selection_method(module):
+ return not (
+ isinstance(module, torch.nn.Sequential)
+ or isinstance(module, torch.nn.ModuleList)
+ ) and not hasattr(module, "_restricted")
+
+ @staticmethod
+ def _flatten_model(module):
+ modules = []
+ child = False
+ for (name, c) in module.named_children():
+ child = True
+ flattened_c = FlattableModel._flatten_model(c)
+ modules += flattened_c
+ if not child and FlattableModel._selection_method(module):
+ modules = [module]
+ return modules
+
+ def get_layer_io(self, layer, nb_samples, data_loader):
+ ios = []
+ hook = layer.register_forward_hook(
+ lambda m, i, o: ios.append((i[0].data.cpu(), o.data.cpu()))
+ )
+
+ nbatch = 1
+ for batch_idx, (xs, ys) in enumerate(data_loader):
+ # -1 takes all of them
+ if nb_samples != -1 and nbatch > nb_samples:
+ break
+ _ = self.model(xs.cuda())
+ nbatch += 1
+
+ hook.remove()
+ return ios
+
+ def get_attributes_name(self):
+ def _real_get_attributes_name(module):
+ modules = []
+ child = False
+ for (name, c) in module.named_children():
+ child = True
+ flattened_c = _real_get_attributes_name(c)
+ modules += map(lambda e: [name] + e, flattened_c)
+ if not child and FlattableModel._selection_method(module):
+ modules = [[]]
+ return modules
+
+ return _real_get_attributes_name(self.model)
+
+ def update_model(self, flat_model):
+ """
+ Take a list representing the flattened model and rebuild its internals.
+ :type flat_model: List[nn.Module]
+ """
+
+ def _apply_changes_on_layer(block, idxs, layer):
+ assert len(idxs) > 0
+ if len(idxs) == 1:
+ setattr(block, idxs[0], layer)
+ else:
+ _apply_changes_on_layer(getattr(block, idxs[0]), idxs[1:], layer)
+
+ def _apply_changes_model(model_list):
+ for i in range(len(model_list)):
+ _apply_changes_on_layer(self.model, self._attr_names[i], model_list[i])
+
+ _apply_changes_model(flat_model)
+ self._attr_names = self.get_attributes_name()
+ self._flat_model = None
+
+ def cuda(self):
+ self.model = self.model.cuda()
+ return self
+
+ def cpu(self):
+ self.model = self.model.cpu()
+ return self
+
+
+def bn_fuse(model):
+ model = model.cpu()
+ flattable = FlattableModel(model)
+ fmodel = flattable.flatten_model()
+
+ for index, item in enumerate(fmodel):
+ if (
+ isinstance(item, torch.nn.Conv2d)
+ and index + 1 < len(fmodel)
+ and isinstance(fmodel[index + 1], torch.nn.BatchNorm2d)
+ ):
+ alpha, beta = _calculate_alpha_beta(fmodel[index + 1])
+ if item.weight.shape[0] != alpha.shape[0]:
+ # this case happens if there was actually something else
+ # between the conv and the
+ # bn layer which is not picked up in flat model logic. (see densenet)
+ continue
+ item.weight.data = item.weight.data * alpha.view(-1, 1, 1, 1)
+ item.bias = torch.nn.Parameter(beta)
+ fmodel[index + 1] = _IdentityLayer()
+ flattable.update_model(fmodel)
+ return flattable.model
+
+
+def _calculate_alpha_beta(batchnorm_layer):
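+ # BatchNorm computes y = weight * (x - running_mean) / sqrt(running_var + eps) + bias,
+ # i.e. y = alpha * x + beta with
+ #   alpha = weight / sqrt(running_var + eps)
+ #   beta  = bias - weight * running_mean / sqrt(running_var + eps)
+ # so a preceding conv can absorb the BN by scaling its weights by alpha
+ # and taking beta as its new bias (see bn_fuse above).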
+ alpha = batchnorm_layer.weight.data / (
+ torch.sqrt(batchnorm_layer.running_var + batchnorm_layer.eps)
+ )
+ beta = (
+ -(batchnorm_layer.weight.data * batchnorm_layer.running_mean)
+ / (torch.sqrt(batchnorm_layer.running_var + batchnorm_layer.eps))
+ + batchnorm_layer.bias.data
+ )
+ alpha = alpha.cpu()
+ beta = beta.cpu()
+ return alpha, beta
+
+
+class _IdentityLayer(torch.nn.Module):
+ def forward(self, input):
+ return input
diff --git a/climategan/data.py b/climategan/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..e57fc21725361e02185088a909a74e36a8cd3fc4
--- /dev/null
+++ b/climategan/data.py
@@ -0,0 +1,539 @@
+"""Data-loading functions in order to create a Dataset and DataLoaders.
+Transforms for loaders are in transforms.py
+"""
+
+import json
+import os
+from pathlib import Path
+
+import numpy as np
+import torch
+import yaml
+from imageio import imread
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+from climategan.transforms import get_transforms
+from climategan.tutils import get_normalized_depth_t
+from climategan.utils import env_to_path, is_image_file
+
+classes_dict = {
+ "s": { # unity
+ 0: [0, 0, 255, 255], # Water
+ 1: [55, 55, 55, 255], # Ground
+ 2: [0, 255, 255, 255], # Building
+ 3: [255, 212, 0, 255], # Traffic items
+ 4: [0, 255, 0, 255], # Vegetation
+ 5: [255, 97, 0, 255], # Terrain
+ 6: [255, 0, 0, 255], # Car
+ 7: [60, 180, 60, 255], # Trees
+ 8: [255, 0, 255, 255], # Person
+ 9: [0, 0, 0, 255], # Sky
+ 10: [255, 255, 255, 255], # Default
+ },
+ "r": { # deeplab v2
+ 0: [0, 0, 255, 255], # Water
+ 1: [55, 55, 55, 255], # Ground
+ 2: [0, 255, 255, 255], # Building
+ 3: [255, 212, 0, 255], # Traffic items
+ 4: [0, 255, 0, 255], # Vegetation
+ 5: [255, 97, 0, 255], # Terrain
+ 6: [255, 0, 0, 255], # Car
+ 7: [60, 180, 60, 255], # Trees
+ 8: [220, 20, 60, 255], # Person
+ 9: [8, 19, 49, 255], # Sky
+ 10: [0, 80, 100, 255], # Default
+ },
+ "kitti": {
+ 0: [210, 0, 200], # Terrain
+ 1: [90, 200, 255], # Sky
+ 2: [0, 199, 0], # Tree
+ 3: [90, 240, 0], # Vegetation
+ 4: [140, 140, 140], # Building
+ 5: [100, 60, 100], # Road
+ 6: [250, 100, 255], # GuardRail
+ 7: [255, 255, 0], # TrafficSign
+ 8: [200, 200, 0], # TrafficLight
+ 9: [255, 130, 0], # Pole
+ 10: [80, 80, 80], # Misc
+ 11: [160, 60, 60], # Truck
+ 12: [255, 127, 80], # Car
+ 13: [0, 139, 139], # Van
+ 14: [0, 0, 0], # Undefined
+ },
+ "flood": {
+ 0: [255, 0, 0], # Cannot flood
+ 1: [0, 0, 255], # Must flood
+ 2: [0, 0, 0], # May flood
+ },
+}
+
+kitti_mapping = {
+ 0: 5, # Terrain -> Terrain
+ 1: 9, # Sky -> Sky
+ 2: 7, # Tree -> Trees
+ 3: 4, # Vegetation -> Vegetation
+ 4: 2, # Building -> Building
+ 5: 1, # Road -> Ground
+ 6: 3, # GuardRail -> Traffic items
+ 7: 3, # TrafficSign -> Traffic items
+ 8: 3, # TrafficLight -> Traffic items
+ 9: 3, # Pole -> Traffic items
+ 10: 10, # Misc -> default
+ 11: 6, # Truck -> Car
+ 12: 6, # Car -> Car
+ 13: 6, # Van -> Car
+ 14: 10, # Undefined -> Default
+}
+
+
+def encode_exact_segmap(seg, classes_dict, default_value=14):
+ """
+ When the mapping (rgb -> label) is known to be exact (no approximative rgb values)
+ maps rgb image to segmap labels
+
+ Args:
+ seg (np.ndarray): H x W x 3 RGB image
+ classes_dict (dict): Mapping {class: rgb value}
+ default_value (int, optional): Value for unknown label. Defaults to 14.
+
+ Returns:
+ np.ndarray: Segmap as labels, not RGB
+ """
+ out = np.ones((seg.shape[0], seg.shape[1])) * default_value
+ for cindex, cvalue in classes_dict.items():
+ out[np.where((seg == cvalue).all(-1))] = cindex
+ return out
+
+
+def merge_labels(labels, mapping, default_value=14):
+ """
+ Maps labels from a source domain to labels of a target domain,
+ typically kitti -> climategan
+
+ Args:
+ labels (np.ndarray): input segmap labels
+ mapping (dict): source_label -> target_label
+ default_value (int, optional): Unknown label. Defaults to 14.
+
+ Returns:
+ np.ndarray: Adapted labels
+ """
+ out = np.ones_like(labels) * default_value
+ for source, target in mapping.items():
+ out[labels == source] = target
+ return out
+
+
+def process_kitti_seg(path, kitti_classes, merge_map, default=14):
+ """
+ Processes a path to produce a 1 x 1 x H x W torch segmap
+
+ %timeit process_kitti_seg(path, classes_dict, mapping, default=14)
+ 326 ms ± 118 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
+
+ Args:
+ path (str | pathlib.Path): Segmap RGB path
+ kitti_classes (dict): Kitti map label -> rgb
+ merge_map (dict): map kitti_label -> climategan_label
+ default (int, optional): Unknown kitti label. Defaults to 14.
+
+ Returns:
+ torch.Tensor: 1 x 1 x H x W torch segmap
+ """
+ seg = imread(path)
+ labels = encode_exact_segmap(seg, kitti_classes, default_value=default)
+ merged = merge_labels(labels, merge_map, default_value=default)
+ return torch.tensor(merged).unsqueeze(0).unsqueeze(0)
+
+
+def decode_segmap_merged_labels(tensor, domain, is_target, nc=11):
+ """Creates a label colormap for classes used in Unity segmentation benchmark.
+ Arguments:
+ tensor -- segmented image of size (1) x (nc) x (H) x (W)
+ if prediction, or size (1) x (1) x (H) x (W) if target
+ Returns:
+ RGB tensor of size (1) x (3) x (H) x (W)
+ #"""
+
+ if is_target: # Target is size 1 x 1 x H x W
+ idx = tensor.squeeze(0).squeeze(0)
+ else: # Prediction is size 1 x nc x H x W
+ idx = torch.argmax(tensor.squeeze(0), dim=0)
+
+ indexer = torch.tensor(list(classes_dict[domain].values()))[:, :3]
+ return indexer[idx.long()].permute(2, 0, 1).to(torch.float32).unsqueeze(0)
+
+
+def decode_segmap_cityscapes_labels(image, nc=19):
+ """Creates a label colormap used in CITYSCAPES segmentation benchmark.
+ Arguments:
+ image {array} -- segmented image
+ (array of image size containing class at each pixel)
+ Returns:
+ array of size 3*nc -- A colormap for visualizing segmentation results.
+ """
+ colormap = np.zeros((19, 3), dtype=np.uint8)
+ colormap[0] = [128, 64, 128]
+ colormap[1] = [244, 35, 232]
+ colormap[2] = [70, 70, 70]
+ colormap[3] = [102, 102, 156]
+ colormap[4] = [190, 153, 153]
+ colormap[5] = [153, 153, 153]
+ colormap[6] = [250, 170, 30]
+ colormap[7] = [220, 220, 0]
+ colormap[8] = [107, 142, 35]
+ colormap[9] = [152, 251, 152]
+ colormap[10] = [70, 130, 180]
+ colormap[11] = [220, 20, 60]
+ colormap[12] = [255, 0, 0]
+ colormap[13] = [0, 0, 142]
+ colormap[14] = [0, 0, 70]
+ colormap[15] = [0, 60, 100]
+ colormap[16] = [0, 80, 100]
+ colormap[17] = [0, 0, 230]
+ colormap[18] = [119, 11, 32]
+
+ r = np.zeros_like(image).astype(np.uint8)
+ g = np.zeros_like(image).astype(np.uint8)
+ b = np.zeros_like(image).astype(np.uint8)
+
+ for col in range(nc):
+ idx = image == col
+ r[idx] = colormap[col, 0]
+ g[idx] = colormap[col, 1]
+ b[idx] = colormap[col, 2]
+
+ rgb = np.stack([r, g, b], axis=2)
+ return rgb
+
+
+def find_closest_class(pixel, dict_classes):
+ """Takes a pixel as input and finds the closest known pixel value corresponding
+ to a class in dict_classes
+
+ Arguments:
+ pixel -- tuple pixel (R,G,B,A)
+ Returns:
+ tuple pixel (R,G,B,A) corresponding to a key (a class) in dict_classes
+ """
+ min_dist = float("inf")
+ closest_pixel = None
+ for pixel_value in dict_classes.keys():
+ dist = np.sqrt(np.sum(np.square(np.subtract(pixel, pixel_value))))
+ if dist < min_dist:
+ min_dist = dist
+ closest_pixel = pixel_value
+ return closest_pixel
+
+
+def encode_segmap(arr, domain):
+ """Change a segmentation RGBA array to a segmentation array
+ with each pixel being the index of the class
+ Arguments:
+ numpy array -- segmented image of size (H) x (W) x (4 RGBA values)
+ Returns:
+ numpy array of size (1) x (H) x (W) with each pixel being the index of the class
+ """
+ new_arr = np.zeros((1, arr.shape[0], arr.shape[1]))
+ dict_classes = {
+ tuple(rgba_value): class_id
+ for (class_id, rgba_value) in classes_dict[domain].items()
+ }
+ for i in range(arr.shape[0]):
+ for j in range(arr.shape[1]):
+ pixel_rgba = tuple(arr[i, j, :])
+ if pixel_rgba in dict_classes.keys():
+ new_arr[0, i, j] = dict_classes[pixel_rgba]
+ else:
+ pixel_rgba_closest = find_closest_class(pixel_rgba, dict_classes)
+ new_arr[0, i, j] = dict_classes[pixel_rgba_closest]
+ return new_arr
+
+
+def encode_mask_label(arr, domain):
+ """Change a segmentation RGBA array to a segmentation array
+ with each pixel being the index of the class
+ Arguments:
+ numpy array -- segmented image of size (H) x (W) x (3 RGB values)
+ Returns:
+ numpy array of size (1) x (H) x (W) with each pixel being the index of the class
+ """
+ diff = np.zeros((len(classes_dict[domain].keys()), arr.shape[0], arr.shape[1]))
+ for cindex, cvalue in classes_dict[domain].items():
+ diff[cindex, :, :] = np.sqrt(
+ np.sum(
+ np.square(arr - np.tile(cvalue, (arr.shape[0], arr.shape[1], 1))),
+ axis=2,
+ )
+ )
+ return np.expand_dims(np.argmin(diff, axis=0), axis=0)
+
+
+def transform_segmap_image_to_tensor(path, domain):
+ """
+ Transforms a segmentation image to a tensor of size (1) x (1) x (H) x (W)
+ with each pixel being the index of the class
+ """
+ arr = np.array(Image.open(path).convert("RGBA"))
+ arr = encode_segmap(arr, domain)
+ arr = torch.from_numpy(arr).float()
+ arr = arr.unsqueeze(0)
+ return arr
+
+
+def save_segmap_tensors(path_to_json, path_to_dir, domain):
+ """
+ Loads the segmentation images mentioned in a json file, transforms them to
+ tensors and saves the tensors in the target directory
+
+ Args:
+ path_to_json: complete path to the json file where to find the original data
+ path_to_dir: path to the directory where to save the tensors as tensor_name.pt
+ domain: domain of the images ("r" or "s")
+
+ e.g:
+ save_segmap_tensors(
+ "/network/tmp1/ccai/data/climategan/seg/train_s.json",
+ "/network/tmp1/ccai/data/munit_dataset/simdata/Unity11K_res640/Seg_tensors/",
+ "s",
+ )
+ """
+ ims_list = None
+ if path_to_json:
+ path_to_json = Path(path_to_json).resolve()
+ with open(path_to_json, "r") as f:
+ ims_list = yaml.safe_load(f)
+
+ assert ims_list is not None
+
+ for im_dict in ims_list:
+ for task_name, path in im_dict.items():
+ if task_name == "s":
+ file_name = os.path.splitext(path)[0] # remove extension
+ file_name = file_name.rsplit("/", 1)[-1] # keep only the file_name
+ tensor = transform_segmap_image_to_tensor(path, domain)
+ torch.save(tensor, path_to_dir + file_name + ".pt")
+
+
+def pil_image_loader(path, task):
+ if Path(path).suffix == ".npy":
+ arr = np.load(path).astype(np.uint8)
+ elif is_image_file(path):
+ # arr = imread(path).astype(np.uint8)
+ arr = np.array(Image.open(path).convert("RGB"))
+ else:
+ raise ValueError("Unknown data type {}".format(path))
+
+ # Convert from RGBA to RGB for images
+ if len(arr.shape) == 3 and arr.shape[-1] == 4:
+ arr = arr[:, :, 0:3]
+
+ if task == "m":
+ arr[arr != 0] = 1
+ # Make sure mask is single-channel
+ if len(arr.shape) >= 3:
+ arr = arr[:, :, 0]
+
+ # assert len(arr.shape) == 3, (path, task, arr.shape)
+
+ return Image.fromarray(arr)
+
+
+def tensor_loader(path, task, domain, opts):
+ """load data as tensors
+ Args:
+ path (str): path to data
+ task (str)
+ domain (str)
+ Returns:
+ [Tensor]: 1 x C x H x W
+ """
+ if task == "s":
+ if domain == "kitti":
+ return process_kitti_seg(
+ path, classes_dict["kitti"], kitti_mapping, default=14
+ )
+ return torch.load(path)
+ elif task == "d":
+ if Path(path).suffix == ".npy":
+ arr = np.load(path)
+ else:
+ arr = imread(path) # .astype(np.uint8) /!\ kitti is np.uint16
+ tensor = torch.from_numpy(arr.astype(np.float32))
+ tensor = get_normalized_depth_t(
+ tensor,
+ domain,
+ normalize="d" in opts.train.pseudo.tasks,
+ log=opts.gen.d.classify.enable,
+ )
+ tensor = tensor.unsqueeze(0)
+ return tensor
+
+ elif Path(path).suffix == ".npy":
+ arr = np.load(path).astype(np.float32)
+ elif is_image_file(path):
+ arr = imread(path).astype(np.float32)
+ else:
+ raise ValueError("Unknown data type {}".format(path))
+
+ # Convert from RGBA to RGB for images
+ if len(arr.shape) == 3 and arr.shape[-1] == 4:
+ arr = arr[:, :, 0:3]
+
+ if task == "x":
+ arr -= arr.min()
+ arr /= arr.max()
+ arr = np.moveaxis(arr, 2, 0)
+ elif task == "s":
+ arr = np.moveaxis(arr, 2, 0)
+ elif task == "m":
+ if arr.max() > 127:
+ arr = (arr > 127).astype(arr.dtype)
+ # Make sure mask is single-channel
+ if len(arr.shape) >= 3:
+ arr = arr[:, :, 0]
+ arr = np.expand_dims(arr, 0)
+
+ return torch.from_numpy(arr).unsqueeze(0)
+
+
+class OmniListDataset(Dataset):
+ def __init__(self, mode, domain, opts, transform=None):
+
+ self.opts = opts
+ self.domain = domain
+ self.mode = mode
+ self.tasks = set(opts.tasks)
+ self.tasks.add("x")
+ if "p" in self.tasks:
+ self.tasks.add("m")
+
+ file_list_path = Path(opts.data.files[mode][domain])
+ if "/" not in str(file_list_path):
+ file_list_path = Path(opts.data.files.base) / Path(
+ opts.data.files[mode][domain]
+ )
+
+ if file_list_path.suffix == ".json":
+ self.samples_paths = self.json_load(file_list_path)
+ elif file_list_path.suffix in {".yaml", ".yml"}:
+ self.samples_paths = self.yaml_load(file_list_path)
+ else:
+ raise ValueError("Unknown file list type in {}".format(file_list_path))
+
+ if opts.data.max_samples and opts.data.max_samples != -1:
+ assert isinstance(opts.data.max_samples, int)
+ self.samples_paths = self.samples_paths[: opts.data.max_samples]
+
+ self.filter_samples()
+ if opts.data.check_samples:
+ print(f"Checking samples ({mode}, {domain})")
+ self.check_samples()
+ self.file_list_path = str(file_list_path)
+ self.transform = transform
+
+ def filter_samples(self):
+ """
+ Filter out data which is not required for the model's tasks
+ as defined in opts.tasks
+ """
+ self.samples_paths = [
+ {k: v for k, v in s.items() if k in self.tasks} for s in self.samples_paths
+ ]
+
+ def __getitem__(self, i):
+ """Return an item in the dataset with fields:
+ {
+ data: transform({
+ domains: values
+ }),
+ paths: [{task: path}],
+ domain: [domain],
+ mode: [train|val]
+ }
+ Args:
+ i (int): index of item to retrieve
+ Returns:
+ dict: dataset item where tensors of data are in item["data"] which is a dict
+ {task: tensor}
+ """
+ paths = self.samples_paths[i]
+
+ # always apply transforms,
+ # if no transform is specified, ToTensor and Normalize will be applied
+
+ item = {
+ "data": self.transform(
+ {
+ task: tensor_loader(
+ env_to_path(path),
+ task,
+ self.domain,
+ self.opts,
+ )
+ for task, path in paths.items()
+ }
+ ),
+ "paths": paths,
+ "domain": self.domain if self.domain != "kitti" else "s",
+ "mode": self.mode,
+ }
+
+ return item
+
+ def __len__(self):
+ return len(self.samples_paths)
+
+ def json_load(self, file_path):
+ with open(file_path, "r") as f:
+ return json.load(f)
+
+ def yaml_load(self, file_path):
+ with open(file_path, "r") as f:
+ return yaml.safe_load(f)
+
+ def check_samples(self):
+ """Checks that every file listed in samples_paths actually
+ exists on the file-system
+ """
+ for s in self.samples_paths:
+ for k, v in s.items():
+ assert Path(v).exists(), f"{k} {v} does not exist"
+
+
+def get_loader(mode, domain, opts):
+ if (
+ domain != "kitti"
+ or not opts.train.kitti.pretrain
+ or not opts.train.kitti.batch_size
+ ):
+ batch_size = opts.data.loaders.get("batch_size", 4)
+ else:
+ batch_size = opts.train.kitti.get("batch_size", 4)
+
+ return DataLoader(
+ OmniListDataset(
+ mode,
+ domain,
+ opts,
+ transform=transforms.Compose(get_transforms(opts, mode, domain)),
+ ),
+ batch_size=batch_size,
+ shuffle=True,
+ num_workers=opts.data.loaders.get("num_workers", 8),
+ pin_memory=True, # faster transfer to gpu
+ drop_last=True, # avoids batchnorm problems if the last batch has size 1
+ )
+
+
+def get_all_loaders(opts):
+ loaders = {}
+ for mode in ["train", "val"]:
+ loaders[mode] = {}
+ for domain in opts.domains:
+ if mode in opts.data.files:
+ if domain in opts.data.files[mode]:
+ loaders[mode][domain] = get_loader(mode, domain, opts)
+ return loaders
diff --git a/climategan/deeplab/__init__.py b/climategan/deeplab/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..74eb60572cd4fd5a2fd6ec4247a757c665e75f21
--- /dev/null
+++ b/climategan/deeplab/__init__.py
@@ -0,0 +1,101 @@
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from climategan.deeplab.deeplab_v2 import DeepLabV2Decoder
+from climategan.deeplab.deeplab_v3 import DeepLabV3Decoder
+from climategan.deeplab.mobilenet_v3 import MobileNetV2
+from climategan.deeplab.resnet101_v3 import ResNet101
+from climategan.deeplab.resnetmulti_v2 import ResNetMulti
+
+
+def create_encoder(opts, no_init=False, verbose=0):
+ if opts.gen.encoder.architecture == "deeplabv2":
+ if verbose > 0:
+ print(" - Add Deeplabv2 Encoder")
+ return DeeplabV2Encoder(opts, no_init, verbose)
+ elif opts.gen.encoder.architecture == "deeplabv3":
+ if verbose > 0:
+ backbone = opts.gen.deeplabv3.backbone
+ print(" - Add Deeplabv3 ({}) Encoder".format(backbone))
+ return build_v3_backbone(opts, no_init)
+ else:
+ raise NotImplementedError(
+ "Unknown encoder: {}".format(opts.gen.encoder.architecture)
+ )
+
+
+def create_segmentation_decoder(opts, no_init=False, verbose=0):
+ if opts.gen.s.architecture == "deeplabv2":
+ if verbose > 0:
+ print(" - Add DeepLabV2Decoder")
+ return DeepLabV2Decoder(opts)
+ elif opts.gen.s.architecture == "deeplabv3":
+ if verbose > 0:
+ print(" - Add DeepLabV3Decoder")
+ return DeepLabV3Decoder(opts, no_init)
+ else:
+ raise NotImplementedError(
+ "Unknown Segmentation architecture: {}".format(opts.gen.s.architecture)
+ )
+
+
+def build_v3_backbone(opts, no_init, verbose=0):
+ backbone = opts.gen.deeplabv3.backbone
+ output_stride = opts.gen.deeplabv3.output_stride
+ if backbone == "resnet":
+ resnet = ResNet101(
+ output_stride=output_stride,
+ BatchNorm=nn.BatchNorm2d,
+ verbose=verbose,
+ no_init=no_init,
+ )
+ if not no_init:
+ if opts.gen.deeplabv3.backbone == "resnet":
+ assert Path(opts.gen.deeplabv3.pretrained_model.resnet).exists()
+
+ std = torch.load(opts.gen.deeplabv3.pretrained_model.resnet)
+ resnet.load_state_dict(
+ {
+ k.replace("backbone.", ""): v
+ for k, v in std.items()
+ if k.startswith("backbone.")
+ }
+ )
+ print(
+ " - Loaded pre-trained DeepLabv3+ Resnet101 Backbone as Encoder"
+ )
+ return resnet
+
+ elif opts.gen.deeplabv3.backbone == "mobilenet":
+ assert Path(opts.gen.deeplabv3.pretrained_model.mobilenet).exists()
+ mobilenet = MobileNetV2(
+ no_init=no_init,
+ pretrained_path=opts.gen.deeplabv3.pretrained_model.mobilenet,
+ )
+ print(" - Loaded pre-trained DeepLabv3+ MobileNetV2 Backbone as Encoder")
+ return mobilenet
+
+ else:
+ raise NotImplementedError("Unknown backbone in " + str(opts.gen.deeplabv3))
+
+
+class DeeplabV2Encoder(nn.Module):
+ def __init__(self, opts, no_init=False, verbose=0):
+ """Deeplab architecture encoder"""
+ super().__init__()
+
+ self.model = ResNetMulti(opts.gen.deeplabv2.nblocks, opts.gen.encoder.n_res)
+ if opts.gen.deeplabv2.use_pretrained and not no_init:
+ saved_state_dict = torch.load(opts.gen.deeplabv2.pretrained_model)
+ new_params = self.model.state_dict().copy()
+ for i in saved_state_dict:
+ i_parts = i.split(".")
+ if not i_parts[1] in ["layer5", "resblock"]:
+ new_params[".".join(i_parts[1:])] = saved_state_dict[i]
+ self.model.load_state_dict(new_params)
+ if verbose > 0:
+ print(" - Loaded pretrained weights")
+
+ def forward(self, x):
+ return self.model(x)
diff --git a/climategan/deeplab/deeplab_v2.py b/climategan/deeplab/deeplab_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c097ccf4181588f2f814d2baa564a2213899450
--- /dev/null
+++ b/climategan/deeplab/deeplab_v2.py
@@ -0,0 +1,198 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from climategan.blocks import InterpolateNearest2d
+from climategan.utils import find_target_size
+
+
+class _ASPPModule(nn.Module):
+ # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/aspp.py
+ def __init__(
+ self, inplanes, planes, kernel_size, padding, dilation, BatchNorm, no_init
+ ):
+ super().__init__()
+ self.atrous_conv = nn.Conv2d(
+ inplanes,
+ planes,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding,
+ dilation=dilation,
+ bias=False,
+ )
+ self.bn = BatchNorm(planes)
+ self.relu = nn.ReLU()
+ if not no_init:
+ self._init_weight()
+
+ def forward(self, x):
+ x = self.atrous_conv(x)
+ x = self.bn(x)
+
+ return self.relu(x)
+
+ def _init_weight(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ torch.nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+
+class ASPP(nn.Module):
+ # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/aspp.py
+ def __init__(self, backbone, output_stride, BatchNorm, no_init):
+ super().__init__()
+
+ if backbone == "mobilenet":
+ inplanes = 320
+ else:
+ inplanes = 2048
+
+ if output_stride == 16:
+ dilations = [1, 6, 12, 18]
+ elif output_stride == 8:
+ dilations = [1, 12, 24, 36]
+ else:
+ raise NotImplementedError
+
+ self.aspp1 = _ASPPModule(
+ inplanes,
+ 256,
+ 1,
+ padding=0,
+ dilation=dilations[0],
+ BatchNorm=BatchNorm,
+ no_init=no_init,
+ )
+ self.aspp2 = _ASPPModule(
+ inplanes,
+ 256,
+ 3,
+ padding=dilations[1],
+ dilation=dilations[1],
+ BatchNorm=BatchNorm,
+ no_init=no_init,
+ )
+ self.aspp3 = _ASPPModule(
+ inplanes,
+ 256,
+ 3,
+ padding=dilations[2],
+ dilation=dilations[2],
+ BatchNorm=BatchNorm,
+ no_init=no_init,
+ )
+ self.aspp4 = _ASPPModule(
+ inplanes,
+ 256,
+ 3,
+ padding=dilations[3],
+ dilation=dilations[3],
+ BatchNorm=BatchNorm,
+ no_init=no_init,
+ )
+
+ self.global_avg_pool = nn.Sequential(
+ nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Conv2d(inplanes, 256, 1, stride=1, bias=False),
+ BatchNorm(256),
+ nn.ReLU(),
+ )
+ self.conv1 = nn.Conv2d(1280, 256, 1, bias=False)
+ self.bn1 = BatchNorm(256)
+ self.relu = nn.ReLU()
+ self.dropout = nn.Dropout(0.5)
+ if not no_init:
+ self._init_weight()
+
+ def forward(self, x):
+ x1 = self.aspp1(x)
+ x2 = self.aspp2(x)
+ x3 = self.aspp3(x)
+ x4 = self.aspp4(x)
+ x5 = self.global_avg_pool(x)
+ x5 = F.interpolate(x5, size=x4.size()[2:], mode="bilinear", align_corners=True)
+ x = torch.cat((x1, x2, x3, x4, x5), dim=1)
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+
+ return self.dropout(x)
+
+ def _init_weight(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+ # m.weight.data.normal_(0, math.sqrt(2. / n))
+ torch.nn.init.kaiming_normal_(m.weight)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+
+
+class DeepLabV2Decoder(nn.Module):
+ # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/decoder.py
+ # https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
+ def __init__(self, opts, no_init=False):
+ super().__init__()
+ self.aspp = ASPP("resnet", 16, nn.BatchNorm2d, no_init)
+ self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada
+
+ conv_modules = [
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.ReLU(),
+ nn.Dropout(0.5),
+ nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
+ nn.BatchNorm2d(256),
+ nn.ReLU(),
+ nn.Dropout(0.1),
+ ]
+ if opts.gen.s.upsample_featuremaps:
+ conv_modules = [InterpolateNearest2d(scale_factor=2)] + conv_modules
+
+ conv_modules += [
+ nn.Conv2d(256, opts.gen.s.output_dim, kernel_size=1, stride=1),
+ ]
+ self.conv = nn.Sequential(*conv_modules)
+
+ self._target_size = find_target_size(opts, "s")
+ print(
+ " - {}: setting target size to {}".format(
+ self.__class__.__name__, self._target_size
+ )
+ )
+
+ def set_target_size(self, size):
+ """
+ Set final interpolation's target size
+
+ Args:
+ size (int, list, tuple): target size (h, w). If int, target will be (i, i)
+ """
+ if isinstance(size, (list, tuple)):
+ self._target_size = size[:2]
+ else:
+ self._target_size = (size, size)
+
+ def forward(self, z, z_depth=None):
+ if self._target_size is None:
+ error = "self._target_size should be set with self.set_target_size()"
+            error += " to interpolate logits to the target seg map's size"
+ raise Exception(error)
+ if isinstance(z, (list, tuple)):
+ z = z[0]
+ if z.shape[1] != 2048:
+ raise Exception(
+ "Segmentation decoder will only work with 2048 channels for z"
+ )
+
+ if z_depth is not None and self.use_dada:
+ z = z * z_depth
+
+ y = self.aspp(z)
+ y = self.conv(y)
+ return F.interpolate(y, self._target_size, mode="bilinear", align_corners=True)
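+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the original file): run the ASPP block on a
+    # dummy 2048-channel feature map, as produced by a ResNet encoder.
+    aspp = ASPP("resnet", output_stride=16, BatchNorm=nn.BatchNorm2d, no_init=True)
+    aspp.eval()  # BatchNorm on the 1x1 pooled branch needs eval mode for batch=1
+    with torch.no_grad():
+        out = aspp(torch.randn(1, 2048, 40, 40))
+    print(out.shape)  # torch.Size([1, 256, 40, 40])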
diff --git a/climategan/deeplab/deeplab_v3.py b/climategan/deeplab/deeplab_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..041b567455ceb0aaf075a453eff3c0560471e2d4
--- /dev/null
+++ b/climategan/deeplab/deeplab_v3.py
@@ -0,0 +1,271 @@
+"""
+https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/backbone/resnet.py
+"""
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from climategan.deeplab.mobilenet_v3 import SeparableConv2d
+from climategan.utils import find_target_size
+
+
+class _DeepLabHead(nn.Module):
+ def __init__(
+ self, nclass, c1_channels=256, c4_channels=2048, norm_layer=nn.BatchNorm2d
+ ):
+ super().__init__()
+ last_channels = c4_channels
+ # self.c1_block = _ConvBNReLU(c1_channels, 48, 1, norm_layer=norm_layer)
+ # last_channels += 48
+ self.block = nn.Sequential(
+ SeparableConv2d(
+ last_channels, 256, 3, norm_layer=norm_layer, relu_first=False
+ ),
+ SeparableConv2d(256, 256, 3, norm_layer=norm_layer, relu_first=False),
+ nn.Conv2d(256, nclass, 1),
+ )
+
+ def forward(self, x, c1=None):
+ return self.block(x)
+
+
+class ConvBNReLU(nn.Module):
+ """
+ https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py
+ """
+
+ def __init__(
+ self, in_chan, out_chan, ks=3, stride=1, padding=1, dilation=1, *args, **kwargs
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(
+ in_chan,
+ out_chan,
+ kernel_size=ks,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ bias=True,
+ )
+ self.bn = nn.BatchNorm2d(out_chan)
+ self.init_weight()
+
+    def forward(self, x):
+        # NB: despite the class name, no ReLU is applied here (conv + BN only)
+        x = self.conv(x)
+        x = self.bn(x)
+        return x
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+
+class ASPPv3Plus(nn.Module):
+ """
+ https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py
+ """
+
+ def __init__(self, backbone, no_init):
+ super().__init__()
+
+ if backbone == "mobilenet":
+ in_chan = 320
+ else:
+ in_chan = 2048
+
+ self.with_gp = False
+ self.conv1 = ConvBNReLU(in_chan, 256, ks=1, dilation=1, padding=0)
+ self.conv2 = ConvBNReLU(in_chan, 256, ks=3, dilation=6, padding=6)
+ self.conv3 = ConvBNReLU(in_chan, 256, ks=3, dilation=12, padding=12)
+ self.conv4 = ConvBNReLU(in_chan, 256, ks=3, dilation=18, padding=18)
+ if self.with_gp:
+ self.avg = nn.AdaptiveAvgPool2d((1, 1))
+ self.conv1x1 = ConvBNReLU(in_chan, 256, ks=1)
+ self.conv_out = ConvBNReLU(256 * 5, 256, ks=1)
+ else:
+ self.conv_out = ConvBNReLU(256 * 4, 256, ks=1)
+
+ if not no_init:
+ self.init_weight()
+
+ def forward(self, x):
+ H, W = x.size()[2:]
+ feat1 = self.conv1(x)
+ feat2 = self.conv2(x)
+ feat3 = self.conv3(x)
+ feat4 = self.conv4(x)
+ if self.with_gp:
+ avg = self.avg(x)
+ feat5 = self.conv1x1(avg)
+ feat5 = F.interpolate(feat5, (H, W), mode="bilinear", align_corners=True)
+ feat = torch.cat([feat1, feat2, feat3, feat4, feat5], 1)
+ else:
+ feat = torch.cat([feat1, feat2, feat3, feat4], 1)
+ feat = self.conv_out(feat)
+ return feat
+
+ def init_weight(self):
+ for ly in self.children():
+ if isinstance(ly, nn.Conv2d):
+ nn.init.kaiming_normal_(ly.weight, a=1)
+ if ly.bias is not None:
+ nn.init.constant_(ly.bias, 0)
+
+
+class Decoder(nn.Module):
+ """
+ https://github.com/CoinCheung/DeepLab-v3-plus-cityscapes/blob/master/models/deeplabv3plus.py
+ """
+
+ def __init__(self, n_classes):
+ super(Decoder, self).__init__()
+ self.conv_low = ConvBNReLU(256, 48, ks=1, padding=0)
+ self.conv_cat = nn.Sequential(
+ ConvBNReLU(304, 256, ks=3, padding=1),
+ ConvBNReLU(256, 256, ks=3, padding=1),
+ )
+ self.conv_out = nn.Conv2d(256, n_classes, kernel_size=1, bias=False)
+
+ def forward(self, feat_low, feat_aspp):
+ H, W = feat_low.size()[2:]
+ feat_low = self.conv_low(feat_low)
+ feat_aspp_up = F.interpolate(
+ feat_aspp, (H, W), mode="bilinear", align_corners=True
+ )
+ feat_cat = torch.cat([feat_low, feat_aspp_up], dim=1)
+ feat_out = self.conv_cat(feat_cat)
+ logits = self.conv_out(feat_out)
+ return logits
+
+
+"""
+https://github.com/jfzhang95/pytorch-deeplab-xception/blob/master/modeling/deeplab.py
+"""
+
+
+class DeepLabV3Decoder(nn.Module):
+ def __init__(
+ self,
+ opts,
+ no_init=False,
+ freeze_bn=False,
+ ):
+ super().__init__()
+
+ num_classes = opts.gen.s.output_dim
+ self.backbone = opts.gen.deeplabv3.backbone
+ self.use_dada = ("d" in opts.tasks) and opts.gen.s.use_dada
+
+ if self.backbone == "resnet":
+ self.aspp = ASPPv3Plus(self.backbone, no_init)
+ self.decoder = Decoder(num_classes)
+
+ self.freeze_bn = freeze_bn
+ else:
+ self.head = _DeepLabHead(num_classes, c4_channels=320)
+
+ self._target_size = find_target_size(opts, "s")
+ print(
+ " - {}: setting target size to {}".format(
+ self.__class__.__name__, self._target_size
+ )
+ )
+
+ if not no_init:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ nn.init.zeros_(m.bias)
+
+ self.load_pretrained(opts)
+
+ def load_pretrained(self, opts):
+ assert opts.gen.deeplabv3.backbone in {"resnet", "mobilenet"}
+        if opts.gen.deeplabv3.backbone == "resnet":
+            assert Path(opts.gen.deeplabv3.pretrained_model.resnet).exists()
+            std = torch.load(opts.gen.deeplabv3.pretrained_model.resnet)
+ self.aspp.load_state_dict(
+ {
+ k.replace("aspp.", ""): v
+ for k, v in std.items()
+ if k.startswith("aspp.")
+ }
+ )
+ self.decoder.load_state_dict(
+ {
+ k.replace("decoder.", ""): v
+ for k, v in std.items()
+ if k.startswith("decoder.")
+ and not (len(v.shape) > 0 and v.shape[0] == 19)
+ },
+ strict=False,
+ )
+ print(
+ "- Loaded pre-trained DeepLabv3+ (Resnet) Decoder & ASPP as Seg Decoder"
+ )
+        else:
+            assert Path(opts.gen.deeplabv3.pretrained_model.mobilenet).exists()
+            std = torch.load(opts.gen.deeplabv3.pretrained_model.mobilenet)
+ self.load_state_dict(
+ {
+ k: v
+ for k, v in std.items()
+ if k.startswith("head.")
+ and not (len(v.shape) > 0 and v.shape[0] == 19)
+ },
+ strict=False,
+ )
+ print(
+ " - Loaded pre-trained DeepLabv3+ (MobileNetV2) Head as Seg Decoder"
+ )
+
+ def set_target_size(self, size):
+ """
+ Set final interpolation's target size
+
+ Args:
+ size (int, list, tuple): target size (h, w). If int, target will be (i, i)
+ """
+ if isinstance(size, (list, tuple)):
+ self._target_size = size[:2]
+ else:
+ self._target_size = (size, size)
+
+ def forward(self, z, z_depth=None):
+ assert isinstance(z, (tuple, list))
+ if self._target_size is None:
+ error = "self._target_size should be set with self.set_target_size()"
+            error += " to interpolate logits to the target seg map's size"
+ raise ValueError(error)
+
+ z_high, z_low = z
+
+ if z_depth is not None and self.use_dada:
+ z_high = z_high * z_depth
+
+ if self.backbone == "resnet":
+ z_high = self.aspp(z_high)
+ s = self.decoder(z_high, z_low)
+ else:
+ s = self.head(z_high)
+
+ s = F.interpolate(
+ s, size=self._target_size, mode="bilinear", align_corners=True
+ )
+
+ return s
+
+    def freeze_bn(self):
+        # Note: for the "resnet" backbone, __init__ also sets a boolean attribute
+        # self.freeze_bn, which shadows this method on those instances.
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eval()
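+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the original file): chain ASPPv3Plus and
+    # Decoder on dummy (high-level, low-level) features, mirroring the call made
+    # inside DeepLabV3Decoder for the "resnet" backbone. The 11 output classes
+    # are only an example value.
+    aspp = ASPPv3Plus("resnet", no_init=True)
+    dec = Decoder(n_classes=11)
+    aspp.eval()
+    dec.eval()
+    with torch.no_grad():
+        z_high = torch.randn(1, 2048, 40, 40)  # high-level encoder features
+        z_low = torch.randn(1, 256, 160, 160)  # low-level encoder features
+        logits = dec(aspp(z_high), z_low)  # same argument order as DeepLabV3Decoder
+    print(logits.shape)  # torch.Size([1, 11, 40, 40]); upsampled later to target size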
diff --git a/climategan/deeplab/mobilenet_v3.py b/climategan/deeplab/mobilenet_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba08e60a2ab529627c93505b9a2cb81522ab518
--- /dev/null
+++ b/climategan/deeplab/mobilenet_v3.py
@@ -0,0 +1,324 @@
+"""
+from https://github.com/LikeLy-Journey/SegmenTron/blob/
+4bc605eedde7d680314f63d329277b73f83b1c5f/segmentron/modules/basic.py#L34
+"""
+
+from collections import OrderedDict
+from pathlib import Path
+
+import torch
+import torch.nn as nn
+from climategan.blocks import InterpolateNearest2d
+
+
+class SeparableConv2d(nn.Module):
+ def __init__(
+ self,
+ inplanes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ dilation=1,
+ relu_first=True,
+ bias=False,
+ norm_layer=nn.BatchNorm2d,
+ ):
+ super().__init__()
+ depthwise = nn.Conv2d(
+ inplanes,
+ inplanes,
+ kernel_size,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ groups=inplanes,
+ bias=bias,
+ )
+ bn_depth = norm_layer(inplanes)
+ pointwise = nn.Conv2d(inplanes, planes, 1, bias=bias)
+ bn_point = norm_layer(planes)
+
+ if relu_first:
+ self.block = nn.Sequential(
+ OrderedDict(
+ [
+ ("relu", nn.ReLU()),
+ ("depthwise", depthwise),
+ ("bn_depth", bn_depth),
+ ("pointwise", pointwise),
+ ("bn_point", bn_point),
+ ]
+ )
+ )
+ else:
+ self.block = nn.Sequential(
+ OrderedDict(
+ [
+ ("depthwise", depthwise),
+ ("bn_depth", bn_depth),
+ ("relu1", nn.ReLU(inplace=True)),
+ ("pointwise", pointwise),
+ ("bn_point", bn_point),
+ ("relu2", nn.ReLU(inplace=True)),
+ ]
+ )
+ )
+
+ def forward(self, x):
+ return self.block(x)
+
+
+class _ConvBNReLU(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ relu6=False,
+ norm_layer=nn.BatchNorm2d,
+ ):
+ super(_ConvBNReLU, self).__init__()
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias=False,
+ )
+ self.bn = norm_layer(out_channels)
+ self.relu = nn.ReLU6(True) if relu6 else nn.ReLU(True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ x = self.relu(x)
+ return x
+
+
+class _DepthwiseConv(nn.Module):
+ """conv_dw in MobileNet"""
+
+ def __init__(
+ self, in_channels, out_channels, stride, norm_layer=nn.BatchNorm2d, **kwargs
+ ):
+ super(_DepthwiseConv, self).__init__()
+ self.conv = nn.Sequential(
+ _ConvBNReLU(
+ in_channels,
+ in_channels,
+ 3,
+ stride,
+ 1,
+ groups=in_channels,
+ norm_layer=norm_layer,
+ ),
+ _ConvBNReLU(in_channels, out_channels, 1, norm_layer=norm_layer),
+ )
+
+ def forward(self, x):
+ return self.conv(x)
+
+
+class InvertedResidual(nn.Module):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ stride,
+ expand_ratio,
+ dilation=1,
+ norm_layer=nn.BatchNorm2d,
+ ):
+ super(InvertedResidual, self).__init__()
+ assert stride in [1, 2]
+ self.use_res_connect = stride == 1 and in_channels == out_channels
+
+ layers = list()
+ inter_channels = int(round(in_channels * expand_ratio))
+ if expand_ratio != 1:
+ # pw
+ layers.append(
+ _ConvBNReLU(
+ in_channels, inter_channels, 1, relu6=True, norm_layer=norm_layer
+ )
+ )
+ layers.extend(
+ [
+ # dw
+ _ConvBNReLU(
+ inter_channels,
+ inter_channels,
+ 3,
+ stride,
+ dilation,
+ dilation,
+ groups=inter_channels,
+ relu6=True,
+ norm_layer=norm_layer,
+ ),
+ # pw-linear
+ nn.Conv2d(inter_channels, out_channels, 1, bias=False),
+ norm_layer(out_channels),
+ ]
+ )
+ self.conv = nn.Sequential(*layers)
+
+ def forward(self, x):
+ if self.use_res_connect:
+ return x + self.conv(x)
+ else:
+ return self.conv(x)
+
+
+class MobileNetV2(nn.Module):
+ def __init__(self, norm_layer=nn.BatchNorm2d, pretrained_path=None, no_init=False):
+ super(MobileNetV2, self).__init__()
+ output_stride = 16
+ self.multiplier = 1.0
+ if output_stride == 32:
+ dilations = [1, 1]
+ elif output_stride == 16:
+ dilations = [1, 2]
+ elif output_stride == 8:
+ dilations = [2, 4]
+ else:
+ raise NotImplementedError
+ inverted_residual_setting = [
+ # t, c, n, s
+ [1, 16, 1, 1],
+ [6, 24, 2, 2],
+ [6, 32, 3, 2],
+ [6, 64, 4, 2],
+ [6, 96, 3, 1],
+ [6, 160, 3, 2],
+ [6, 320, 1, 1],
+ ]
+ # building first layer
+ input_channels = int(32 * self.multiplier) if self.multiplier > 1.0 else 32
+ # last_channels = int(1280 * multiplier) if multiplier > 1.0 else 1280
+ self.conv1 = _ConvBNReLU(
+ 3, input_channels, 3, 2, 1, relu6=True, norm_layer=norm_layer
+ )
+
+ # building inverted residual blocks
+ self.planes = input_channels
+ self.block1 = self._make_layer(
+ InvertedResidual,
+ self.planes,
+ inverted_residual_setting[0:1],
+ norm_layer=norm_layer,
+ )
+ self.block2 = self._make_layer(
+ InvertedResidual,
+ self.planes,
+ inverted_residual_setting[1:2],
+ norm_layer=norm_layer,
+ )
+ self.block3 = self._make_layer(
+ InvertedResidual,
+ self.planes,
+ inverted_residual_setting[2:3],
+ norm_layer=norm_layer,
+ )
+ self.block4 = self._make_layer(
+ InvertedResidual,
+ self.planes,
+ inverted_residual_setting[3:5],
+ dilations[0],
+ norm_layer=norm_layer,
+ )
+ self.block5 = self._make_layer(
+ InvertedResidual,
+ self.planes,
+ inverted_residual_setting[5:],
+ dilations[1],
+ norm_layer=norm_layer,
+ )
+ self.last_inp_channels = self.planes
+
+ self.up2 = InterpolateNearest2d()
+
+ # weight initialization
+ if not no_init:
+ self.pretrained_path = pretrained_path
+ if pretrained_path is not None:
+ self._load_pretrained_model()
+ else:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out")
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Linear):
+ nn.init.normal_(m.weight, 0, 0.01)
+ if m.bias is not None:
+ nn.init.zeros_(m.bias)
+
+ def _make_layer(
+ self,
+ block,
+ planes,
+ inverted_residual_setting,
+ dilation=1,
+ norm_layer=nn.BatchNorm2d,
+ ):
+ features = list()
+ for t, c, n, s in inverted_residual_setting:
+ out_channels = int(c * self.multiplier)
+ stride = s if dilation == 1 else 1
+ features.append(
+ block(planes, out_channels, stride, t, dilation, norm_layer)
+ )
+ planes = out_channels
+ for i in range(n - 1):
+ features.append(
+ block(planes, out_channels, 1, t, norm_layer=norm_layer)
+ )
+ planes = out_channels
+ self.planes = planes
+ return nn.Sequential(*features)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.block1(x)
+ c1 = self.block2(x)
+ c2 = self.block3(c1)
+ c3 = self.block4(c2)
+ c4 = self.up2(self.block5(c3))
+
+ # x = self.features(x)
+ # x = self.classifier(x.view(x.size(0), x.size(1)))
+ return c4, c1
+
+ def _load_pretrained_model(self):
+ assert self.pretrained_path is not None
+ assert Path(self.pretrained_path).exists()
+
+ pretrain_dict = torch.load(self.pretrained_path)
+ pretrain_dict = {k.replace("encoder.", ""): v for k, v in pretrain_dict.items()}
+ model_dict = {}
+ state_dict = self.state_dict()
+ ignored = []
+ for k, v in pretrain_dict.items():
+ if k in state_dict:
+ model_dict[k] = v
+ else:
+ ignored.append(k)
+ state_dict.update(model_dict)
+ self.load_state_dict(state_dict)
+ self.loaded_pre_trained = True
+ print(
+ " - Loaded pre-trained MobileNetV2: ignored {}/{} keys".format(
+ len(ignored), len(pretrain_dict)
+ )
+ )
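+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the original file): forward a dummy image
+    # through a randomly-initialized MobileNetV2 backbone; it returns a
+    # (high-level, low-level) feature pair like the ResNet backbone does.
+    net = MobileNetV2(pretrained_path=None, no_init=True)
+    net.eval()
+    with torch.no_grad():
+        c4, c1 = net(torch.randn(1, 3, 640, 640))
+    print(c4.shape)  # 320-channel high-level features
+    print(c1.shape)  # 24-channel low-level features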
diff --git a/climategan/deeplab/resnet101_v3.py b/climategan/deeplab/resnet101_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cf6d73b7f41e40e81d45cf180c7bdb216b5eb23
--- /dev/null
+++ b/climategan/deeplab/resnet101_v3.py
@@ -0,0 +1,203 @@
+import torch.nn as nn
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(
+ self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None
+ ):
+ super(Bottleneck, self).__init__()
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+ self.bn1 = BatchNorm(planes)
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=stride,
+ dilation=dilation,
+ padding=dilation,
+ bias=False,
+ )
+ self.bn2 = BatchNorm(planes)
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = BatchNorm(planes * 4)
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+
+ out = self.conv3(out)
+ out = self.bn3(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNet(nn.Module):
+ def __init__(
+ self, block, layers, output_stride, BatchNorm, verbose=0, no_init=False
+ ):
+ self.inplanes = 64
+ self.verbose = verbose
+ super(ResNet, self).__init__()
+ blocks = [1, 2, 4]
+ if output_stride == 16:
+ strides = [1, 2, 2, 1]
+ dilations = [1, 1, 1, 2]
+ elif output_stride == 8:
+ strides = [1, 2, 1, 1]
+ dilations = [1, 1, 2, 4]
+ else:
+ raise NotImplementedError
+
+ # Modules
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = BatchNorm(64)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ self.layer1 = self._make_layer(
+ block,
+ 64,
+ layers[0],
+ stride=strides[0],
+ dilation=dilations[0],
+ BatchNorm=BatchNorm,
+ )
+ self.layer2 = self._make_layer(
+ block,
+ 128,
+ layers[1],
+ stride=strides[1],
+ dilation=dilations[1],
+ BatchNorm=BatchNorm,
+ )
+ self.layer3 = self._make_layer(
+ block,
+ 256,
+ layers[2],
+ stride=strides[2],
+ dilation=dilations[2],
+ BatchNorm=BatchNorm,
+ )
+ self.layer4 = self._make_MG_unit(
+ block,
+ 512,
+ blocks=blocks,
+ stride=strides[3],
+ dilation=dilations[3],
+ BatchNorm=BatchNorm,
+ )
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ BatchNorm(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)
+ )
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)
+ )
+
+ return nn.Sequential(*layers)
+
+ def _make_MG_unit(
+ self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None
+ ):
+ downsample = None
+ if stride != 1 or self.inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ BatchNorm(planes * block.expansion),
+ )
+
+ layers = []
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ stride,
+ dilation=blocks[0] * dilation,
+ downsample=downsample,
+ BatchNorm=BatchNorm,
+ )
+ )
+ self.inplanes = planes * block.expansion
+ for i in range(1, len(blocks)):
+ layers.append(
+ block(
+ self.inplanes,
+ planes,
+ stride=1,
+ dilation=blocks[i] * dilation,
+ BatchNorm=BatchNorm,
+ )
+ )
+
+ return nn.Sequential(*layers)
+
+ def forward(self, input):
+ x = self.conv1(input)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ low_level_feat = x
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ return x, low_level_feat
+
+
+def ResNet101(output_stride=8, BatchNorm=nn.BatchNorm2d, verbose=0, no_init=False):
+ """Constructs a ResNet-101 model.
+ Args:
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
+ """
+ model = ResNet(
+ Bottleneck,
+ [3, 4, 23, 3],
+ output_stride,
+ BatchNorm,
+ verbose=verbose,
+ no_init=no_init,
+ )
+ return model
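+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the original file): the backbone returns
+    # the final feature map and the low-level features used by the DeepLabv3+
+    # decoder.
+    import torch
+
+    model = ResNet101(output_stride=16, BatchNorm=nn.BatchNorm2d, no_init=True)
+    model.eval()
+    with torch.no_grad():
+        x, low = model(torch.randn(1, 3, 512, 512))
+    print(x.shape)  # torch.Size([1, 2048, 32, 32])
+    print(low.shape)  # torch.Size([1, 256, 128, 128])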
diff --git a/climategan/deeplab/resnetmulti_v2.py b/climategan/deeplab/resnetmulti_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe36361f3ea41d182e348ffb98fb9160e718bf88
--- /dev/null
+++ b/climategan/deeplab/resnetmulti_v2.py
@@ -0,0 +1,136 @@
+import torch.nn as nn
+from climategan.blocks import ResBlocks
+
+affine_par = True
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None):
+ super(Bottleneck, self).__init__()
+ # change
+ self.conv1 = nn.Conv2d(
+ inplanes, planes, kernel_size=1, stride=stride, bias=False
+ )
+ self.bn1 = nn.BatchNorm2d(planes, affine=affine_par)
+ for i in self.bn1.parameters():
+ i.requires_grad = False
+ padding = dilation
+ # change
+ self.conv2 = nn.Conv2d(
+ planes,
+ planes,
+ kernel_size=3,
+ stride=1,
+ padding=padding,
+ bias=False,
+ dilation=dilation,
+ )
+ self.bn2 = nn.BatchNorm2d(planes, affine=affine_par)
+ for i in self.bn2.parameters():
+ i.requires_grad = False
+ self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * 4, affine=affine_par)
+ for i in self.bn3.parameters():
+ i.requires_grad = False
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ out = self.relu(out)
+ out = self.conv3(out)
+ out = self.bn3(out)
+ if self.downsample is not None:
+ residual = self.downsample(x)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ResNetMulti(nn.Module):
+ def __init__(
+ self,
+ layers,
+ n_res=4,
+ res_norm="instance",
+ activ="lrelu",
+ pad_type="reflect",
+ ):
+ self.inplanes = 64
+ block = Bottleneck
+ super(ResNetMulti, self).__init__()
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = nn.BatchNorm2d(64, affine=affine_par)
+ for i in self.bn1.parameters():
+ i.requires_grad = False
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(
+ kernel_size=3, stride=2, padding=0, ceil_mode=True
+ ) # changed padding from 1 to 0
+ self.layer1 = self._make_layer(block, 64, layers[0])
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2)
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ m.weight.data.normal_(0, 0.01)
+ elif isinstance(m, nn.BatchNorm2d):
+ m.weight.data.fill_(1)
+ m.bias.data.zero_()
+ self.layer_res = ResBlocks(
+ n_res, 2048, norm=res_norm, activation=activ, pad_type=pad_type
+ )
+
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
+ downsample = None
+ if (
+ stride != 1
+ or self.inplanes != planes * block.expansion
+ or dilation == 2
+ or dilation == 4
+ ):
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ self.inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False,
+ ),
+ nn.BatchNorm2d(planes * block.expansion, affine=affine_par),
+ )
+ for i in downsample._modules["1"].parameters():
+ i.requires_grad = False
+ layers = []
+ layers.append(
+ block(
+ self.inplanes, planes, stride, dilation=dilation, downsample=downsample
+ )
+ )
+ self.inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(block(self.inplanes, planes, dilation=dilation))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ x = self.layer_res(x)
+ return x
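+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the original file): the DeepLabv2-style
+    # trunk keeps a 1/8 output stride thanks to the dilated layer3 and layer4.
+    import torch
+
+    net = ResNetMulti([3, 4, 23, 3], n_res=1)
+    net.eval()
+    with torch.no_grad():
+        feats = net(torch.randn(1, 3, 256, 256))
+    print(feats.shape)  # torch.Size([1, 2048, 32, 32])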
diff --git a/climategan/depth.py b/climategan/depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8d408448b82b1d11043131b61897b8467192e65
--- /dev/null
+++ b/climategan/depth.py
@@ -0,0 +1,230 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from climategan.blocks import BaseDecoder, Conv2dBlock, InterpolateNearest2d
+from climategan.utils import find_target_size
+
+
+def create_depth_decoder(opts, no_init=False, verbose=0):
+ if opts.gen.d.architecture == "base":
+ decoder = BaseDepthDecoder(opts)
+ if "s" in opts.task:
+ assert opts.gen.s.use_dada is False
+ if "m" in opts.tasks:
+ assert opts.gen.m.use_dada is False
+ else:
+ decoder = DADADepthDecoder(opts)
+
+ if verbose > 0:
+ print(f" - Add {decoder.__class__.__name__}")
+
+ return decoder
+
+
+class DADADepthDecoder(nn.Module):
+ """
+ Depth decoder based on depth auxiliary task in DADA paper
+ """
+
+ def __init__(self, opts):
+ super().__init__()
+ if (
+ opts.gen.encoder.architecture == "deeplabv3"
+ and opts.gen.deeplabv3.backbone == "mobilenet"
+ ):
+ res_dim = 320
+ else:
+ res_dim = 2048
+
+ mid_dim = 512
+
+ self.do_feat_fusion = False
+ if opts.gen.m.use_dada or ("s" in opts.tasks and opts.gen.s.use_dada):
+ self.do_feat_fusion = True
+ self.dec4 = Conv2dBlock(
+ 128,
+ res_dim,
+ 1,
+ stride=1,
+ padding=0,
+ bias=True,
+ activation="lrelu",
+ norm="none",
+ )
+
+ self.relu = nn.ReLU(inplace=True)
+ self.enc4_1 = Conv2dBlock(
+ res_dim,
+ mid_dim,
+ 1,
+ stride=1,
+ padding=0,
+ bias=False,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="batch",
+ )
+ self.enc4_2 = Conv2dBlock(
+ mid_dim,
+ mid_dim,
+ 3,
+ stride=1,
+ padding=1,
+ bias=False,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="batch",
+ )
+ self.enc4_3 = Conv2dBlock(
+ mid_dim,
+ 128,
+ 1,
+ stride=1,
+ padding=0,
+ bias=False,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="batch",
+ )
+ self.upsample = None
+ if opts.gen.d.upsample_featuremaps:
+ self.upsample = nn.Sequential(
+ *[
+ InterpolateNearest2d(),
+ Conv2dBlock(
+ 128,
+ 32,
+ 3,
+ stride=1,
+ padding=1,
+ bias=False,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="batch",
+ ),
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
+ ]
+ )
+ self._target_size = find_target_size(opts, "d")
+ print(
+ " - {}: setting target size to {}".format(
+ self.__class__.__name__, self._target_size
+ )
+ )
+
+ def set_target_size(self, size):
+ """
+ Set final interpolation's target size
+
+ Args:
+ size (int, list, tuple): target size (h, w). If int, target will be (i, i)
+ """
+ if isinstance(size, (list, tuple)):
+ self._target_size = size[:2]
+ else:
+ self._target_size = (size, size)
+
+ def forward(self, z):
+ if isinstance(z, (list, tuple)):
+ z = z[0]
+ z4_enc = self.enc4_1(z)
+ z4_enc = self.enc4_2(z4_enc)
+ z4_enc = self.enc4_3(z4_enc)
+
+ z_depth = None
+ if self.do_feat_fusion:
+ z_depth = self.dec4(z4_enc)
+
+ if self.upsample is not None:
+ z4_enc = self.upsample(z4_enc)
+
+ depth = torch.mean(z4_enc, dim=1, keepdim=True) # DADA paper decoder
+ if depth.shape[-1] != self._target_size:
+ depth = F.interpolate(
+ depth,
+ size=(384, 384), # size used in MiDaS inference
+ mode="bicubic", # what MiDaS uses
+ align_corners=False,
+ )
+
+ depth = F.interpolate(
+ depth, (self._target_size, self._target_size), mode="nearest"
+ ) # what we used in the transforms to resize input
+
+ return depth, z_depth
+
+ def __str__(self):
+ return "DADA Depth Decoder"
+
+
+class BaseDepthDecoder(BaseDecoder):
+ def __init__(self, opts):
+ low_level_feats_dim = -1
+ use_v3 = opts.gen.encoder.architecture == "deeplabv3"
+ use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
+ use_low = opts.gen.d.use_low_level_feats
+
+ if use_v3 and use_mobile_net:
+ input_dim = 320
+ if use_low:
+ low_level_feats_dim = 24
+ elif use_v3:
+ input_dim = 2048
+ if use_low:
+ low_level_feats_dim = 256
+ else:
+ input_dim = 2048
+
+ n_upsample = 1 if opts.gen.d.upsample_featuremaps else 0
+ output_dim = (
+ 1
+ if not opts.gen.d.classify.enable
+ else opts.gen.d.classify.linspace.buckets
+ )
+
+ self._target_size = find_target_size(opts, "d")
+ print(
+ " - {}: setting target size to {}".format(
+ self.__class__.__name__, self._target_size
+ )
+ )
+
+ super().__init__(
+ n_upsample=n_upsample,
+ n_res=opts.gen.d.n_res,
+ input_dim=input_dim,
+ proj_dim=opts.gen.d.proj_dim,
+ output_dim=output_dim,
+ norm=opts.gen.d.norm,
+ activ=opts.gen.d.activ,
+ pad_type=opts.gen.d.pad_type,
+ output_activ="none",
+ low_level_feats_dim=low_level_feats_dim,
+ )
+
+ def set_target_size(self, size):
+ """
+ Set final interpolation's target size
+
+ Args:
+ size (int, list, tuple): target size (h, w). If int, target will be (i, i)
+ """
+ if isinstance(size, (list, tuple)):
+ self._target_size = size[:2]
+ else:
+ self._target_size = (size, size)
+
+ def forward(self, z, cond=None):
+ if self._target_size is None:
+ error = "self._target_size should be set with self.set_target_size()"
+            error += " to interpolate depth to the target depth map's size"
+ raise ValueError(error)
+
+ d = super().forward(z)
+
+ preds = F.interpolate(
+ d, size=self._target_size, mode="bilinear", align_corners=True
+ )
+
+ return preds, None
diff --git a/climategan/discriminator.py b/climategan/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..97404bebae0a1751489f8962cda31be2b1f02303
--- /dev/null
+++ b/climategan/discriminator.py
@@ -0,0 +1,361 @@
+"""Discriminator architecture for ClimateGAN's GAN components (a and t)
+"""
+import functools
+
+import torch
+import torch.nn as nn
+
+from climategan.blocks import SpectralNorm
+from climategan.tutils import init_weights
+
+# from torch.optim import lr_scheduler
+
+# mainly from https://github.com/sangwoomo/instagan/blob/master/models/networks.py
+
+
+def create_discriminator(opts, device, no_init=False, verbose=0):
+ disc = OmniDiscriminator(opts)
+ if no_init:
+ return disc
+
+ for task, model in disc.items():
+ if isinstance(model, nn.ModuleDict):
+ for domain, domain_model in model.items():
+ init_weights(
+ domain_model,
+ init_type=opts.dis[task].init_type,
+ init_gain=opts.dis[task].init_gain,
+ verbose=verbose,
+ caller=f"create_discriminator {task} {domain}",
+ )
+ else:
+ init_weights(
+ model,
+ init_type=opts.dis[task].init_type,
+ init_gain=opts.dis[task].init_gain,
+ verbose=verbose,
+ caller=f"create_discriminator {task}",
+ )
+ return disc.to(device)
+
+
+def define_D(
+ input_nc,
+ ndf,
+ n_layers=3,
+ norm="batch",
+ use_sigmoid=False,
+ get_intermediate_features=False,
+ num_D=1,
+):
+ norm_layer = get_norm_layer(norm_type=norm)
+ net = MultiscaleDiscriminator(
+ input_nc,
+ ndf,
+ n_layers=n_layers,
+ norm_layer=norm_layer,
+ use_sigmoid=use_sigmoid,
+ get_intermediate_features=get_intermediate_features,
+ num_D=num_D,
+ )
+ return net
+
+
+def get_norm_layer(norm_type="instance"):
+ if not norm_type:
+ print("norm_type is {}, defaulting to instance")
+ norm_type = "instance"
+ if norm_type == "batch":
+ norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
+ elif norm_type == "instance":
+ norm_layer = functools.partial(
+ nn.InstanceNorm2d, affine=False, track_running_stats=False
+ )
+ elif norm_type == "none":
+ norm_layer = None
+ else:
+ raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
+ return norm_layer
+
+
+# Defines the PatchGAN discriminator with the specified arguments.
+class NLayerDiscriminator(nn.Module):
+ def __init__(
+ self,
+ input_nc=3,
+ ndf=64,
+ n_layers=3,
+ norm_layer=nn.BatchNorm2d,
+ use_sigmoid=False,
+ get_intermediate_features=True,
+ ):
+ super(NLayerDiscriminator, self).__init__()
+ if type(norm_layer) == functools.partial:
+ use_bias = norm_layer.func == nn.InstanceNorm2d
+ else:
+ use_bias = norm_layer == nn.InstanceNorm2d
+
+ self.get_intermediate_features = get_intermediate_features
+
+ kw = 4
+ padw = 1
+ sequence = [
+ [
+ # Use spectral normalization
+ SpectralNorm(
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw)
+ ),
+ nn.LeakyReLU(0.2, True),
+ ]
+ ]
+
+ nf_mult = 1
+ nf_mult_prev = 1
+ for n in range(1, n_layers):
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n, 8)
+ sequence += [
+ [
+ # Use spectral normalization
+ SpectralNorm( # TODO replace with Conv2dBlock
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=2,
+ padding=padw,
+ bias=use_bias,
+ )
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+ ]
+
+ nf_mult_prev = nf_mult
+ nf_mult = min(2 ** n_layers, 8)
+ sequence += [
+ [
+ # Use spectral normalization
+ SpectralNorm(
+ nn.Conv2d(
+ ndf * nf_mult_prev,
+ ndf * nf_mult,
+ kernel_size=kw,
+ stride=1,
+ padding=padw,
+ bias=use_bias,
+ )
+ ),
+ norm_layer(ndf * nf_mult),
+ nn.LeakyReLU(0.2, True),
+ ]
+ ]
+
+ # Use spectral normalization
+ sequence += [
+ [
+ SpectralNorm(
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
+ )
+ ]
+ ]
+
+ if use_sigmoid:
+ sequence += [[nn.Sigmoid()]]
+
+ # We divide the layers into groups to extract intermediate layer outputs
+ for n in range(len(sequence)):
+ self.add_module("model" + str(n), nn.Sequential(*sequence[n]))
+ # self.model = nn.Sequential(*sequence)
+
+ def forward(self, input):
+ results = [input]
+ for submodel in self.children():
+ intermediate_output = submodel(results[-1])
+ results.append(intermediate_output)
+
+ get_intermediate_features = self.get_intermediate_features
+ if get_intermediate_features:
+ return results[1:]
+ else:
+ return results[-1]
+
+
+# Source: https://github.com/NVIDIA/pix2pixHD
+class MultiscaleDiscriminator(nn.Module):
+ def __init__(
+ self,
+ input_nc=3,
+ ndf=64,
+ n_layers=3,
+ norm_layer=nn.BatchNorm2d,
+ use_sigmoid=False,
+ get_intermediate_features=True,
+ num_D=3,
+ ):
+ super(MultiscaleDiscriminator, self).__init__()
+ # self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d,
+ # use_sigmoid=False, num_D=3, getIntermFeat=False
+
+ self.n_layers = n_layers
+ self.ndf = ndf
+ self.norm_layer = norm_layer
+ self.use_sigmoid = use_sigmoid
+ self.get_intermediate_features = get_intermediate_features
+ self.num_D = num_D
+
+ for i in range(self.num_D):
+ netD = NLayerDiscriminator(
+ input_nc=input_nc,
+ ndf=self.ndf,
+ n_layers=self.n_layers,
+ norm_layer=self.norm_layer,
+ use_sigmoid=self.use_sigmoid,
+ get_intermediate_features=self.get_intermediate_features,
+ )
+ self.add_module("discriminator_%d" % i, netD)
+
+ self.downsample = nn.AvgPool2d(
+ 3, stride=2, padding=[1, 1], count_include_pad=False
+ )
+
+ def forward(self, input):
+ result = []
+ get_intermediate_features = self.get_intermediate_features
+ for name, D in self.named_children():
+ if "discriminator" not in name:
+ continue
+ out = D(input)
+ if not get_intermediate_features:
+ out = [out]
+ result.append(out)
+ input = self.downsample(input)
+
+ return result
+
+
+class OmniDiscriminator(nn.ModuleDict):
+ def __init__(self, opts):
+ super().__init__()
+ if "p" in opts.tasks:
+ if opts.dis.p.use_local_discriminator:
+
+ self["p"] = nn.ModuleDict(
+ {
+ "global": define_D(
+ input_nc=3,
+ ndf=opts.dis.p.ndf,
+ n_layers=opts.dis.p.n_layers,
+ norm=opts.dis.p.norm,
+ use_sigmoid=opts.dis.p.use_sigmoid,
+ get_intermediate_features=opts.dis.p.get_intermediate_features, # noqa: E501
+ num_D=opts.dis.p.num_D,
+ ),
+ "local": define_D(
+ input_nc=3,
+ ndf=opts.dis.p.ndf,
+ n_layers=opts.dis.p.n_layers,
+ norm=opts.dis.p.norm,
+ use_sigmoid=opts.dis.p.use_sigmoid,
+ get_intermediate_features=opts.dis.p.get_intermediate_features, # noqa: E501
+ num_D=opts.dis.p.num_D,
+ ),
+ }
+ )
+ else:
+ self["p"] = define_D(
+ input_nc=4, # image + mask
+ ndf=opts.dis.p.ndf,
+ n_layers=opts.dis.p.n_layers,
+ norm=opts.dis.p.norm,
+ use_sigmoid=opts.dis.p.use_sigmoid,
+ get_intermediate_features=opts.dis.p.get_intermediate_features,
+ num_D=opts.dis.p.num_D,
+ )
+ if "m" in opts.tasks:
+ if opts.gen.m.use_advent:
+ if opts.dis.m.architecture == "base":
+ if opts.dis.m.gan_type == "WGAN_norm":
+ self["m"] = nn.ModuleDict(
+ {
+ "Advent": get_fc_discriminator(
+ num_classes=2, use_norm=True
+ )
+ }
+ )
+ else:
+ self["m"] = nn.ModuleDict(
+ {
+ "Advent": get_fc_discriminator(
+ num_classes=2, use_norm=False
+ )
+ }
+ )
+ elif opts.dis.m.architecture == "OmniDiscriminator":
+ self["m"] = nn.ModuleDict(
+ {
+ "Advent": define_D(
+ input_nc=2,
+ ndf=opts.dis.m.ndf,
+ n_layers=opts.dis.m.n_layers,
+ norm=opts.dis.m.norm,
+ use_sigmoid=opts.dis.m.use_sigmoid,
+ get_intermediate_features=opts.dis.m.get_intermediate_features, # noqa: E501
+ num_D=opts.dis.m.num_D,
+ )
+ }
+ )
+ else:
+ raise Exception("This Discriminator is currently not supported!")
+ if "s" in opts.tasks:
+ if opts.gen.s.use_advent:
+ if opts.dis.s.gan_type == "WGAN_norm":
+ self["s"] = nn.ModuleDict(
+ {"Advent": get_fc_discriminator(num_classes=11, use_norm=True)}
+ )
+ else:
+ self["s"] = nn.ModuleDict(
+ {"Advent": get_fc_discriminator(num_classes=11, use_norm=False)}
+ )
+
+
+def get_fc_discriminator(num_classes=2, ndf=64, use_norm=False):
+ if use_norm:
+ return torch.nn.Sequential(
+ SpectralNorm(
+ torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1)
+ ),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ SpectralNorm(
+ torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1)
+ ),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ SpectralNorm(
+ torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1)
+ ),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ SpectralNorm(
+ torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1)
+ ),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ SpectralNorm(
+ torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1)
+ ),
+ )
+ else:
+ return torch.nn.Sequential(
+ torch.nn.Conv2d(num_classes, ndf, kernel_size=4, stride=2, padding=1),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ torch.nn.Conv2d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ torch.nn.Conv2d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ torch.nn.Conv2d(ndf * 4, ndf * 8, kernel_size=4, stride=2, padding=1),
+ torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ torch.nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=2, padding=1),
+ )
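+
+
+if __name__ == "__main__":
+    # Illustrative sketch (not part of the original file): the two discriminator
+    # flavours defined above, run on dummy inputs.
+    x = torch.randn(2, 3, 256, 256)
+
+    # Multi-scale PatchGAN; with get_intermediate_features=True each scale returns
+    # the list of its intermediate feature maps (used for feature-matching losses).
+    netD = define_D(
+        input_nc=3, ndf=64, n_layers=3, get_intermediate_features=True, num_D=2
+    )
+    outs = netD(x)
+    print(len(outs), [f.shape for f in outs[0]])
+
+    # Fully-convolutional discriminator (ADVENT-style) on a 2-class probability map.
+    fcD = get_fc_discriminator(num_classes=2)
+    print(fcD(torch.softmax(torch.randn(2, 2, 256, 256), dim=1)).shape)
+    # -> torch.Size([2, 1, 8, 8])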
diff --git a/climategan/eval_metrics.py b/climategan/eval_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..b985413c2595339258ca473673989bd9b4ab2d27
--- /dev/null
+++ b/climategan/eval_metrics.py
@@ -0,0 +1,635 @@
+import cv2
+import numpy as np
+import torch
+from skimage import filters
+from sklearn.metrics.pairwise import euclidean_distances
+import matplotlib.pyplot as plt
+import seaborn as sns
+from copy import deepcopy
+
+# ------------------------------------------------------------------------------
+# ----- Evaluation metrics for a pair of binary mask images (pred, target) -----
+# ------------------------------------------------------------------------------
+
+
+def get_accuracy(arr1, arr2):
+ """pixel accuracy
+
+ Args:
+ arr1 (np.array)
+ arr2 (np.array)
+ """
+ return (arr1 == arr2).sum() / arr1.size
+
+
+def trimap(pred_im, gt_im, thickness=8):
+ """Compute accuracy in a region of thickness around the contours
+ for binary images (0-1 values)
+ Args:
+ pred_im (Image): Prediction
+ gt_im (Image): Target
+        thickness (int, optional): width (px) of the contour band. Defaults to 8.
+ """
+ W, H = gt_im.size
+ contours, hierarchy = cv2.findContours(
+ np.array(gt_im), mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_SIMPLE
+ )
+ mask_contour = np.zeros((H, W), dtype=np.int32)
+ cv2.drawContours(
+ mask_contour, contours, -1, (1), thickness=thickness, hierarchy=hierarchy
+ )
+ gt_contour = np.array(gt_im)[np.where(mask_contour > 0)]
+ pred_contour = np.array(pred_im)[np.where(mask_contour > 0)]
+ return get_accuracy(pred_contour, gt_contour)
+
+
+def iou(pred_im, gt_im):
+ """
+ IoU for binary masks (0-1 values)
+
+ Args:
+        pred_im (np.array): binary mask prediction
+        gt_im (np.array): binary mask ground truth
+ """
+ pred = np.array(pred_im)
+ gt = np.array(gt_im)
+ intersection = (pred * gt).sum()
+ union = (pred + gt).sum() - intersection
+ return intersection / union
+
+
+def f1_score(pred_im, gt_im):
+ pred = np.array(pred_im)
+ gt = np.array(gt_im)
+ intersection = (pred * gt).sum()
+ return 2 * intersection / (pred + gt).sum()
+
+
+def accuracy(pred_im, gt_im):
+ pred = np.array(pred_im)
+ gt = np.array(gt_im)
+    if len(gt.shape) == 4:
+        assert gt.shape[1] == 1
+        gt = gt[:, 0, :, :]
+    if len(pred.shape) > len(gt.shape):
+        pred = np.argmax(pred, axis=1)
+ return float((pred == gt).sum()) / gt.size
+
+
+def mIOU(pred, label, average="macro"):
+ """
+ Adapted from:
+ https://stackoverflow.com/questions/62461379/multiclass-semantic-segmentation-model-evaluation
+
+ Compute the mean IOU from pred and label tensors
+    pred is a tensor N x C x H x W with logits (argmax over the class dim is applied)
+ and label is a N x H x W tensor with int labels per pixel
+
+ this does the same as sklearn's jaccard_score function if you choose average="macro"
+ Args:
+ pred (torch.tensor): predicted logits
+ label (torch.tensor): labels
+ average: "macro" or "weighted"
+
+ Returns:
+ float: mIOU, can be nan
+ """
+ num_classes = pred.shape[-3]
+
+ pred = torch.argmax(pred, dim=1).squeeze(1)
+ present_iou_list = list()
+ pred = pred.view(-1)
+ label = label.view(-1)
+ # Note: Following for loop goes from 0 to (num_classes-1)
+ # and ignore_index is num_classes, thus ignore_index is
+ # not considered in computation of IoU.
+ interesting_classes = (
+ [*range(num_classes)] if num_classes > 2 else [int(label.max().item())]
+ )
+ weights = []
+
+ for sem_class in interesting_classes:
+ pred_inds = pred == sem_class
+ target_inds = label == sem_class
+ if (target_inds.long().sum().item() > 0) or (pred_inds.long().sum().item() > 0):
+ intersection_now = (pred_inds[target_inds]).long().sum().item()
+ union_now = (
+ pred_inds.long().sum().item()
+ + target_inds.long().sum().item()
+ - intersection_now
+ )
+ weights.append(pred_inds.long().sum().item())
+ iou_now = float(intersection_now) / float(union_now)
+ present_iou_list.append(iou_now)
+ if not present_iou_list:
+ return float("nan")
+ elif average == "weighted":
+ weighted_avg = np.sum(np.multiply(weights, present_iou_list) / np.sum(weights))
+ return weighted_avg
+ else:
+ return np.mean(present_iou_list)
+
+
+def masker_classification_metrics(
+ pred, label, labels_dict={"cannot": 0, "must": 1, "may": 2}
+):
+ """
+ Classification metrics for the masker, and the corresponding maps. If the
+ predictions are soft, the errors are weighted accordingly. Metrics computed:
+
+ tpr : float
+ True positive rate
+
+ tpt : float
+ True positive total (divided by total population)
+
+ tnr : float
+ True negative rate
+
+ tnt : float
+ True negative total (divided by total population)
+
+ fpr : float
+ False positive rate: rate of predicted mask on cannot flood
+
+ fpt : float
+ False positive total (divided by total population)
+
+ fnr : float
+ False negative rate: rate of missed mask on must flood
+
+ fnt : float
+ False negative total (divided by total population)
+
+ mnr : float
+ "May" negative rate (labeled as "may", predicted as no-mask)
+
+ mpr : float
+ "May" positive rate (labeled as "may", predicted as mask)
+
+ accuracy : float
+ Accuracy
+
+ error : float
+ Error
+
+ precision : float
+ Precision, considering only cannot and must flood labels
+
+ f05 : float
+ F0.5 score, considering only cannot and must flood labels
+
+ accuracy_must_may : float
+ Accuracy considering only the must and may areas
+
+ Parameters
+ ----------
+ pred : array-like
+ Mask prediction
+
+ label : array-like
+ Mask ground truth labels
+
+ labels_dict : dict
+ A dictionary with the identifier of each class (cannot, must, may)
+
+ Returns
+ -------
+ metrics_dict : dict
+ A dictionary with metric name and value pairs
+
+ maps_dict : dict
+ A dictionary containing the metric maps
+ """
+ tp_map = pred * np.asarray(label == labels_dict["must"], dtype=int)
+ tpr = np.sum(tp_map) / np.sum(label == labels_dict["must"])
+ tpt = np.sum(tp_map) / np.prod(label.shape)
+ tn_map = (1.0 - pred) * np.asarray(label == labels_dict["cannot"], dtype=int)
+ tnr = np.sum(tn_map) / np.sum(label == labels_dict["cannot"])
+ tnt = np.sum(tn_map) / np.prod(label.shape)
+ fp_map = pred * np.asarray(label == labels_dict["cannot"], dtype=int)
+ fpr = np.sum(fp_map) / np.sum(label == labels_dict["cannot"])
+ fpt = np.sum(fp_map) / np.prod(label.shape)
+ fn_map = (1.0 - pred) * np.asarray(label == labels_dict["must"], dtype=int)
+ fnr = np.sum(fn_map) / np.sum(label == labels_dict["must"])
+ fnt = np.sum(fn_map) / np.prod(label.shape)
+ may_neg_map = (1.0 - pred) * np.asarray(label == labels_dict["may"], dtype=int)
+ may_pos_map = pred * np.asarray(label == labels_dict["may"], dtype=int)
+ mnr = np.sum(may_neg_map) / np.sum(label == labels_dict["may"])
+ mpr = np.sum(may_pos_map) / np.sum(label == labels_dict["may"])
+ accuracy = tpt + tnt
+ error = fpt + fnt
+
+ # Assertions
+ assert np.isclose(tpr, 1.0 - fnr), "TPR: {:.4f}, FNR: {:.4f}".format(tpr, fnr)
+ assert np.isclose(tnr, 1.0 - fpr), "TNR: {:.4f}, FPR: {:.4f}".format(tnr, fpr)
+ assert np.isclose(mpr, 1.0 - mnr), "MPR: {:.4f}, MNR: {:.4f}".format(mpr, mnr)
+
+ precision = np.sum(tp_map) / (np.sum(tp_map) + np.sum(fp_map) + 1e-9)
+ beta = 0.5
+ f05 = ((1 + beta ** 2) * precision * tpr) / (beta ** 2 * precision + tpr + 1e-9)
+ accuracy_must_may = (np.sum(tp_map) + np.sum(may_neg_map)) / (
+ np.sum(label == labels_dict["must"]) + np.sum(label == labels_dict["may"])
+ )
+
+ metrics_dict = {
+ "tpr": tpr,
+ "tpt": tpt,
+ "tnr": tnr,
+ "tnt": tnt,
+ "fpr": fpr,
+ "fpt": fpt,
+ "fnr": fnr,
+ "fnt": fnt,
+ "mpr": mpr,
+ "mnr": mnr,
+ "accuracy": accuracy,
+ "error": error,
+ "precision": precision,
+ "f05": f05,
+ "accuracy_must_may": accuracy_must_may,
+ }
+ maps_dict = {
+ "tp": tp_map,
+ "tn": tn_map,
+ "fp": fp_map,
+ "fn": fn_map,
+ "may_pos": may_pos_map,
+ "may_neg": may_neg_map,
+ }
+
+ return metrics_dict, maps_dict
+
+
+def pred_cannot(pred, label, label_cannot=0):
+ """
+ Metric for the masker: Computes false positive rate and its map. If the
+ predictions are soft, the errors are weighted accordingly.
+
+ Parameters
+ ----------
+ pred : array-like
+ Mask prediction
+
+ label : array-like
+ Mask ground truth labels
+
+ label_cannot : int
+ The label index of "cannot flood"
+
+ Returns
+ -------
+ fp_map : array-like
+ The map of false positives: predicted mask on cannot flood
+
+ fpr : float
+ False positive rate: rate of predicted mask on cannot flood
+ """
+ fp_map = pred * np.asarray(label == label_cannot, dtype=int)
+ fpr = np.sum(fp_map) / np.sum(label == label_cannot)
+ return fp_map, fpr
+
+
+def missed_must(pred, label, label_must=1):
+ """
+ Metric for the masker: Computes false negative rate and its map. If the
+ predictions are soft, the errors are weighted accordingly.
+
+ Parameters
+ ----------
+ pred : array-like
+ Mask prediction
+
+ label : array-like
+ Mask ground truth labels
+
+ label_must : int
+ The label index of "must flood"
+
+ Returns
+ -------
+ fn_map : array-like
+ The map of false negatives: missed mask on must flood
+
+ fnr : float
+ False negative rate: rate of missed mask on must flood
+ """
+ fn_map = (1.0 - pred) * np.asarray(label == label_must, dtype=int)
+ fnr = np.sum(fn_map) / np.sum(label == label_must)
+ return fn_map, fnr
+
+
+def may_flood(pred, label, label_may=2):
+ """
+ Metric for the masker: Computes "may" negative and "may" positive rates and their
+ map. If the predictions are soft, the "errors" are weighted accordingly.
+
+ Parameters
+ ----------
+ pred : array-like
+ Mask prediction
+
+ label : array-like
+ Mask ground truth labels
+
+ label_may : int
+ The label index of "may flood"
+
+ Returns
+ -------
+ may_neg_map : array-like
+ The map of "may" negatives
+
+ may_pos_map : array-like
+ The map of "may" positives
+
+ mnr : float
+ "May" negative rate
+
+ mpr : float
+ "May" positive rate
+ """
+ may_neg_map = (1.0 - pred) * np.asarray(label == label_may, dtype=int)
+ may_pos_map = pred * np.asarray(label == label_may, dtype=int)
+ mnr = np.sum(may_neg_map) / np.sum(label == label_may)
+ mpr = np.sum(may_pos_map) / np.sum(label == label_may)
+ return may_neg_map, may_pos_map, mnr, mpr
+
+
+def masker_metrics(pred, label, label_cannot=0, label_must=1):
+ """
+ Computes a set of metrics for the masker
+
+ Parameters
+ ----------
+ pred : array-like
+ Mask prediction
+
+ label : array-like
+ Mask ground truth labels
+
+ label_must : int
+ The label index of "must flood"
+
+ label_cannot : int
+ The label index of "cannot flood"
+
+ Returns
+ -------
+ tpr : float
+ True positive rate
+
+ tnr : float
+ True negative rate
+
+ precision : float
+ Precision, considering only cannot and must flood labels
+
+ f1 : float
+ F1 score, considering only cannot and must flood labels
+ """
+ tp_map = pred * np.asarray(label == label_must, dtype=int)
+ tpr = np.sum(tp_map) / np.sum(label == label_must)
+ tn_map = (1.0 - pred) * np.asarray(label == label_cannot, dtype=int)
+ tnr = np.sum(tn_map) / np.sum(label == label_cannot)
+ fp_map = pred * np.asarray(label == label_cannot, dtype=int)
+ fn_map = (1.0 - pred) * np.asarray(label == label_must, dtype=int) # noqa: F841
+ precision = np.sum(tp_map) / (np.sum(tp_map) + np.sum(fp_map))
+ f1 = 2 * (precision * tpr) / (precision + tpr)
+ return tpr, tnr, precision, f1
+
+
+def get_confusion_matrix(tpr, tnr, fpr, fnr, mpr, mnr):
+ """
+ Constructs the confusion matrix of a masker prediction over a set of samples
+
+ Parameters
+ ----------
+ tpr : vector-like
+ True positive rate
+
+ tnr : vector-like
+ True negative rate
+
+ fpr : vector-like
+ False positive rate
+
+ fnr : vector-like
+ False negative rate
+
+ mpr : vector-like
+ "May" positive rate
+
+ mnr : vector-like
+ "May" negative rate
+
+ Returns
+ -------
+ confusion_matrix : 3x3 array
+ Confusion matrix: [i, j] = [pred, true]
+ | tnr fnr mnr |
+ | fpr tpr mpr |
+        | 0.  0.  0. |
+
+ confusion_matrix_std : 3x3 array
+ Standard deviation of the confusion matrix
+ """
+ # Compute mean and standard deviations over all samples
+ tpr_m = np.mean(tpr)
+ tpr_s = np.std(tpr)
+ tnr_m = np.mean(tnr)
+ tnr_s = np.std(tnr)
+ fpr_m = np.mean(fpr)
+ fpr_s = np.std(fpr)
+ fnr_m = np.mean(fnr)
+ fnr_s = np.std(fnr)
+ mpr_m = np.mean(mpr)
+ mpr_s = np.std(mpr)
+ mnr_m = np.mean(mnr)
+ mnr_s = np.std(mnr)
+
+ # Assertions
+ assert np.isclose(tpr_m, 1.0 - fnr_m), "TPR: {:.4f}, FNR: {:.4f}".format(
+ tpr_m, fnr_m
+ )
+ assert np.isclose(tnr_m, 1.0 - fpr_m), "TNR: {:.4f}, FPR: {:.4f}".format(
+ tnr_m, fpr_m
+ )
+ assert np.isclose(mpr_m, 1.0 - mnr_m), "MPR: {:.4f}, MNR: {:.4f}".format(
+ mpr_m, mnr_m
+ )
+
+ # Fill confusion matrix
+ confusion_matrix = np.zeros((3, 3))
+ confusion_matrix[0, 0] = tnr_m
+ confusion_matrix[0, 1] = fnr_m
+ confusion_matrix[0, 2] = mnr_m
+ confusion_matrix[1, 0] = fpr_m
+ confusion_matrix[1, 1] = tpr_m
+ confusion_matrix[1, 2] = mpr_m
+ confusion_matrix[2, 2] = 0.0
+
+ # Standard deviation
+ confusion_matrix_std = np.zeros((3, 3))
+ confusion_matrix_std[0, 0] = tnr_s
+ confusion_matrix_std[0, 1] = fnr_s
+ confusion_matrix_std[0, 2] = mnr_s
+ confusion_matrix_std[1, 0] = fpr_s
+ confusion_matrix_std[1, 1] = tpr_s
+ confusion_matrix_std[1, 2] = mpr_s
+ confusion_matrix_std[2, 2] = 0.0
+ return confusion_matrix, confusion_matrix_std
+
+
+def edges_coherence_std_min(pred, label, label_must=1, bin_th=0.5):
+ """
+ The standard deviation of the minimum distance between the edge of the prediction
+ and the edge of the "must flood" label.
+
+ Parameters
+ ----------
+ pred : array-like
+ Mask prediction
+
+ label : array-like
+ Mask ground truth labels
+
+ label_must : int
+ The label index of "must flood"
+
+ bin_th : float
+ The threshold for the binarization of the prediction
+
+ Returns
+ -------
+ metric : float
+ The value of the metric
+
+ pred_edge : array-like
+ The edges images of the prediction, for visualization
+
+ label_edge : array-like
+ The edges images of the "must flood" label, for visualization
+ """
+    # Keep the "must flood" label only: 1 where must flood, 0 elsewhere
+    label = np.asarray(deepcopy(label) == label_must, dtype=float)
+
+ # Binarize prediction
+ pred = np.asarray(pred > bin_th, dtype=float)
+
+ # Compute edges
+ pred = filters.sobel(pred)
+ label = filters.sobel(label)
+
+ # Location of edges
+ pred_coord = np.argwhere(pred > 0)
+ label_coord = np.argwhere(label > 0)
+
+ # Handle blank predictions
+ if pred_coord.shape[0] == 0:
+ return 1.0, pred, label
+
+ # Normalized pairwise distances between pred and label
+ dist_mat = np.divide(euclidean_distances(pred_coord, label_coord), pred.shape[0])
+
+ # Standard deviation of the minimum distance from pred to label
+ edge_coherence = np.std(np.min(dist_mat, axis=1))
+
+ return edge_coherence, pred, label
+
+
+def boxplot_metric(
+ output_filename,
+ df,
+ metric,
+ dict_metrics,
+ do_stripplot=False,
+ dict_models=None,
+ dpi=300,
+ **snskwargs
+):
+ f = plt.figure(dpi=dpi)
+
+ if do_stripplot:
+ ax = sns.boxplot(x="model", y=metric, data=df, fliersize=0.0, **snskwargs)
+ ax = sns.stripplot(
+ x="model", y=metric, data=df, size=2.0, color="gray", **snskwargs
+ )
+ else:
+ ax = sns.boxplot(x="model", y=metric, data=df, **snskwargs)
+
+ # Set axes labels
+ ax.set_xlabel("Models", rotation=0, fontsize="medium")
+ ax.set_ylabel(dict_metrics[metric], rotation=90, fontsize="medium")
+
+ # Spines
+ sns.despine(left=True, bottom=True)
+
+ # X-Tick labels
+ if dict_models:
+ xticklabels = [dict_models[t.get_text()] for t in ax.get_xticklabels()]
+ ax.set_xticklabels(
+ xticklabels,
+ rotation=20,
+ verticalalignment="top",
+ horizontalalignment="right",
+ fontsize="xx-small",
+ )
+
+ f.savefig(
+ output_filename,
+ dpi=f.dpi,
+ bbox_inches="tight",
+ facecolor="white",
+ transparent=False,
+ )
+ f.clear()
+ plt.close(f)
+
+
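+# A minimal usage sketch for the helper above, assuming a tidy DataFrame with
+# one row per (model, image) pair; the model names, metric name and output
+# path below are illustrative only.
+def _boxplot_metric_example(output_filename="example_boxplot.png"):
+    import pandas as pd
+
+    rng = np.random.default_rng(0)
+    df = pd.DataFrame(
+        {
+            "model": np.repeat(["model_a", "model_b", "model_c"], 20),
+            "error": rng.random(60),
+        }
+    )
+    dict_metrics = {"error": "Error (lower is better)"}
+    boxplot_metric(output_filename, df, "error", dict_metrics, do_stripplot=True)
+
+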
+def clustermap_metric(
+ output_filename,
+ df,
+ metric,
+ dict_metrics,
+ method="average",
+ cluster_metric="euclidean",
+ dict_models=None,
+ dpi=300,
+ **snskwargs
+):
+ ax_grid = sns.clustermap(data=df, method=method, metric=cluster_metric, **snskwargs)
+ ax_heatmap = ax_grid.ax_heatmap
+ ax_cbar = ax_grid.ax_cbar
+
+ # Set axes labels
+ ax_heatmap.set_xlabel("Models", rotation=0, fontsize="medium")
+ ax_heatmap.set_ylabel("Images", rotation=90, fontsize="medium")
+
+ # Set title
+ ax_cbar.set_title(dict_metrics[metric], rotation=0, fontsize="x-large")
+
+ # X-Tick labels
+ if dict_models:
+ xticklabels = [dict_models[t.get_text()] for t in ax_heatmap.get_xticklabels()]
+ ax_heatmap.set_xticklabels(
+ xticklabels,
+ rotation=20,
+ verticalalignment="top",
+ horizontalalignment="right",
+ fontsize="small",
+ )
+
+ ax_grid.fig.savefig(
+ output_filename,
+ dpi=dpi,
+ bbox_inches="tight",
+ facecolor="white",
+ transparent=False,
+ )
+ ax_grid.fig.clear()
+ plt.close(ax_grid.fig)
diff --git a/climategan/fid.py b/climategan/fid.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e8fbd0bc5130fb1d0aa5031f060c937b635eca2
--- /dev/null
+++ b/climategan/fid.py
@@ -0,0 +1,561 @@
+# from https://github.com/mseitzer/pytorch-fid
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+from scipy import linalg
+from torch.nn.functional import adaptive_avg_pool2d
+
+try:
+ from torchvision.models.utils import load_state_dict_from_url
+except ImportError:
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
+
+FID_WEIGHTS_URL = (
+ "https://github.com/mseitzer/pytorch-fid/releases/download/"
+ + "fid_weights/pt_inception-2015-12-05-6726825d.pth"
+)
+
+
+class InceptionV3(nn.Module):
+ """Pretrained InceptionV3 network returning feature maps"""
+
+ # Index of default block of inception to return,
+ # corresponds to output of final average pooling
+ DEFAULT_BLOCK_INDEX = 3
+
+ # Maps feature dimensionality to their output blocks indices
+ BLOCK_INDEX_BY_DIM = {
+ 64: 0, # First max pooling features
+ 192: 1, # Second max pooling features
+ 768: 2, # Pre-aux classifier features
+ 2048: 3, # Final average pooling features
+ }
+
+ def __init__(
+ self,
+ output_blocks=[DEFAULT_BLOCK_INDEX],
+ resize_input=True,
+ normalize_input=True,
+ requires_grad=False,
+ use_fid_inception=True,
+ ):
+ """Build pretrained InceptionV3
+ Parameters
+ ----------
+ output_blocks : list of int
+ Indices of blocks to return features of. Possible values are:
+ - 0: corresponds to output of first max pooling
+ - 1: corresponds to output of second max pooling
+ - 2: corresponds to output which is fed to aux classifier
+ - 3: corresponds to output of final average pooling
+ resize_input : bool
+ If true, bilinearly resizes input to width and height 299 before
+ feeding input to model. As the network without fully connected
+ layers is fully convolutional, it should be able to handle inputs
+ of arbitrary size, so resizing might not be strictly needed
+ normalize_input : bool
+ If true, scales the input from range (0, 1) to the range the
+ pretrained Inception network expects, namely (-1, 1)
+ requires_grad : bool
+ If true, parameters of the model require gradients. Possibly useful
+ for finetuning the network
+ use_fid_inception : bool
+ If true, uses the pretrained Inception model used in Tensorflow's
+ FID implementation. If false, uses the pretrained Inception model
+ available in torchvision. The FID Inception model has different
+ weights and a slightly different structure from torchvision's
+ Inception model. If you want to compute FID scores, you are
+ strongly advised to set this parameter to true to get comparable
+ results.
+ """
+ super(InceptionV3, self).__init__()
+
+ self.resize_input = resize_input
+ self.normalize_input = normalize_input
+ self.output_blocks = sorted(output_blocks)
+ self.last_needed_block = max(output_blocks)
+
+ assert self.last_needed_block <= 3, "Last possible output block index is 3"
+
+ self.blocks = nn.ModuleList()
+
+ if use_fid_inception:
+ inception = fid_inception_v3()
+ else:
+ inception = _inception_v3(pretrained=True)
+
+ # Block 0: input to maxpool1
+ block0 = [
+ inception.Conv2d_1a_3x3,
+ inception.Conv2d_2a_3x3,
+ inception.Conv2d_2b_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ ]
+ self.blocks.append(nn.Sequential(*block0))
+
+ # Block 1: maxpool1 to maxpool2
+ if self.last_needed_block >= 1:
+ block1 = [
+ inception.Conv2d_3b_1x1,
+ inception.Conv2d_4a_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2),
+ ]
+ self.blocks.append(nn.Sequential(*block1))
+
+ # Block 2: maxpool2 to aux classifier
+ if self.last_needed_block >= 2:
+ block2 = [
+ inception.Mixed_5b,
+ inception.Mixed_5c,
+ inception.Mixed_5d,
+ inception.Mixed_6a,
+ inception.Mixed_6b,
+ inception.Mixed_6c,
+ inception.Mixed_6d,
+ inception.Mixed_6e,
+ ]
+ self.blocks.append(nn.Sequential(*block2))
+
+ # Block 3: aux classifier to final avgpool
+ if self.last_needed_block >= 3:
+ block3 = [
+ inception.Mixed_7a,
+ inception.Mixed_7b,
+ inception.Mixed_7c,
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
+ ]
+ self.blocks.append(nn.Sequential(*block3))
+
+ for param in self.parameters():
+ param.requires_grad = requires_grad
+
+ def forward(self, inp):
+ """Get Inception feature maps
+ Parameters
+ ----------
+ inp : torch.autograd.Variable
+ Input tensor of shape Bx3xHxW. Values are expected to be in
+ range (0, 1)
+ Returns
+ -------
+ List of torch.autograd.Variable, corresponding to the selected output
+ block, sorted ascending by index
+ """
+ outp = []
+ x = inp
+
+ if self.resize_input:
+ x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=False)
+
+ if self.normalize_input:
+ x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1)
+
+ for idx, block in enumerate(self.blocks):
+ x = block(x)
+ if idx in self.output_blocks:
+ outp.append(x)
+
+ if idx == self.last_needed_block:
+ break
+
+ return outp
+
+
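+# A minimal sketch of extracting the 2048-d pooled features used for FID,
+# assuming a batch of RGB images already scaled to [0, 1]. Instantiating the
+# model downloads the FID Inception weights on first use.
+def _inception_features_example():
+    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
+    model = InceptionV3([block_idx]).eval()
+    x = torch.rand(4, 3, 256, 256)  # dummy batch in [0, 1]
+    with torch.no_grad():
+        feats = model(x)[0]  # 4 x 2048 x 1 x 1
+    return feats.squeeze(-1).squeeze(-1)  # 4 x 2048
+
+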
+def _inception_v3(*args, **kwargs):
+ """Wraps `torchvision.models.inception_v3`
+ Skips default weight initialization if supported by torchvision version.
+ See https://github.com/mseitzer/pytorch-fid/issues/28.
+ """
+ try:
+ version = tuple(map(int, torchvision.__version__.split(".")[:2]))
+ except ValueError:
+ # Just a caution against weird version strings
+ version = (0,)
+
+ if version >= (0, 6):
+ kwargs["init_weights"] = False
+
+ return torchvision.models.inception_v3(*args, **kwargs)
+
+
+def fid_inception_v3():
+ """Build pretrained Inception model for FID computation
+ The Inception model for FID computation uses a different set of weights
+ and has a slightly different structure than torchvision's Inception.
+ This method first constructs torchvision's Inception and then patches the
+ necessary parts that are different in the FID Inception model.
+ """
+ inception = _inception_v3(num_classes=1008, aux_logits=False, pretrained=False)
+ inception.Mixed_5b = FIDInceptionA(192, pool_features=32)
+ inception.Mixed_5c = FIDInceptionA(256, pool_features=64)
+ inception.Mixed_5d = FIDInceptionA(288, pool_features=64)
+ inception.Mixed_6b = FIDInceptionC(768, channels_7x7=128)
+ inception.Mixed_6c = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6d = FIDInceptionC(768, channels_7x7=160)
+ inception.Mixed_6e = FIDInceptionC(768, channels_7x7=192)
+ inception.Mixed_7b = FIDInceptionE_1(1280)
+ inception.Mixed_7c = FIDInceptionE_2(2048)
+
+ state_dict = load_state_dict_from_url(FID_WEIGHTS_URL, progress=True)
+ inception.load_state_dict(state_dict)
+ return inception
+
+
+class FIDInceptionA(torchvision.models.inception.InceptionA):
+ """InceptionA block patched for FID computation"""
+
+ def __init__(self, in_channels, pool_features):
+ super(FIDInceptionA, self).__init__(in_channels, pool_features)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch5x5 = self.branch5x5_1(x)
+ branch5x5 = self.branch5x5_2(branch5x5)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+ )
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionC(torchvision.models.inception.InceptionC):
+ """InceptionC block patched for FID computation"""
+
+ def __init__(self, in_channels, channels_7x7):
+ super(FIDInceptionC, self).__init__(in_channels, channels_7x7)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch7x7 = self.branch7x7_1(x)
+ branch7x7 = self.branch7x7_2(branch7x7)
+ branch7x7 = self.branch7x7_3(branch7x7)
+
+ branch7x7dbl = self.branch7x7dbl_1(x)
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+ )
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_1(torchvision.models.inception.InceptionE):
+ """First InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_1, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: Tensorflow's average pool does not use the padded zero's in
+ # its average calculation
+ branch_pool = F.avg_pool2d(
+ x, kernel_size=3, stride=1, padding=1, count_include_pad=False
+ )
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+class FIDInceptionE_2(torchvision.models.inception.InceptionE):
+ """Second InceptionE block patched for FID computation"""
+
+ def __init__(self, in_channels):
+ super(FIDInceptionE_2, self).__init__(in_channels)
+
+ def forward(self, x):
+ branch1x1 = self.branch1x1(x)
+
+ branch3x3 = self.branch3x3_1(x)
+ branch3x3 = [
+ self.branch3x3_2a(branch3x3),
+ self.branch3x3_2b(branch3x3),
+ ]
+ branch3x3 = torch.cat(branch3x3, 1)
+
+ branch3x3dbl = self.branch3x3dbl_1(x)
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+ branch3x3dbl = [
+ self.branch3x3dbl_3a(branch3x3dbl),
+ self.branch3x3dbl_3b(branch3x3dbl),
+ ]
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+ # Patch: The FID Inception model uses max pooling instead of average
+ # pooling. This is likely an error in this specific Inception
+ # implementation, as other Inception models use average pooling here
+ # (which matches the description in the paper).
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+ branch_pool = self.branch_pool(branch_pool)
+
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+ return torch.cat(outputs, 1)
+
+
+def compute_val_fid(trainer, verbose=0):
+ """
+    Compute the FID score between the n = opts.train.fid.n_images real images
+    from the validation set (domain rf) and the n fake images painted from
+    those same n validation images
+
+ Args:
+ trainer (climategan.Trainer): trainer to compute the val fid for
+
+ Returns:
+ float: FID score
+ """
+ # get opts params
+ batch_size = trainer.opts.train.fid.get("batch_size", 50)
+ dims = trainer.opts.train.fid.get("dims", 2048)
+
+ # set inception model
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
+ model = InceptionV3([block_idx]).to(trainer.device)
+
+ # first fid computation: compute the real stats, only once
+ if trainer.real_val_fid_stats is None:
+ if verbose > 0:
+ print("Computing real_val_fid_stats for the first time")
+ set_real_val_fid_stats(trainer, model, batch_size, dims)
+
+ # get real stats
+ real_m = trainer.real_val_fid_stats["m"]
+ real_s = trainer.real_val_fid_stats["s"]
+
+ # compute fake images
+ fakes = compute_fakes(trainer)
+ if verbose > 0:
+ print("Computing fake activation statistics")
+ # get fake stats
+ fake_m, fake_s = calculate_activation_statistics(
+ fakes, model, batch_size=batch_size, dims=dims, device=trainer.device
+ )
+ # compute FD between the real and the fake inception stats
+ return calculate_frechet_distance(real_m, real_s, fake_m, fake_s)
+
+
+def set_real_val_fid_stats(trainer, model, batch_size, dims):
+ """
+ Sets the real_val_fid_stats attribute of the trainer with the m and
+ s outputs of calculate_activation_statistics on the real data.
+
+ This needs to be done only once since nothing changes during training here.
+
+ Args:
+ trainer (climategan.Trainer): trainer instance to compute the stats for
+ model (InceptionV3): inception model to get the activations from
+ batch_size (int): inception inference batch size
+ dims (int): dimension selected in the model
+ """
+ # in the rf domain display_size may be different from fid.n_images
+ limit = trainer.opts.train.fid.n_images
+ display_x = torch.stack(
+ [sample["data"]["x"] for sample in trainer.display_images["val"]["rf"][:limit]]
+ ).to(trainer.device)
+ m, s = calculate_activation_statistics(
+ display_x, model, batch_size=batch_size, dims=dims, device=trainer.device
+ )
+ trainer.real_val_fid_stats = {"m": m, "s": s}
+
+
+def compute_fakes(trainer, verbose=0):
+ """
+ Compute current fake inferences
+
+ Args:
+ trainer (climategan.Trainer): trainer instance
+ verbose (int, optional): Print level. Defaults to 0.
+
+ Returns:
+ torch.Tensor: trainer.opts.train.fid.n_images painted images
+ """
+ # in the rf domain display_size may be different from fid.n_images
+ n = trainer.opts.train.fid.n_images
+ bs = trainer.opts.data.loaders.batch_size
+
+ display_batches = [
+ (sample["data"]["x"], sample["data"]["m"])
+ for sample in trainer.display_images["val"]["rf"][:n]
+ ]
+
+ display_x = torch.stack([b[0] for b in display_batches]).to(trainer.device)
+    display_m = torch.stack([b[1] for b in display_batches]).to(trainer.device)
+ nbs = len(display_x) // bs + 1
+
+ fakes = []
+ for b in range(nbs):
+ if verbose > 0:
+ print("computing fakes {}/{}".format(b + 1, nbs), end="\r", flush=True)
+ with torch.no_grad():
+ x = display_x[b * bs : (b + 1) * bs]
+ m = display_m[b * bs : (b + 1) * bs]
+ fake = trainer.G.paint(m, x)
+ fakes.append(fake)
+
+ return torch.cat(fakes, dim=0)
+
+
+def calculate_activation_statistics(
+ images, model, batch_size=50, dims=2048, device="cpu"
+):
+ """Calculation of the statistics used by the FID.
+ Params:
+ -- images : List of images
+ -- model : Instance of inception model
+ -- batch_size : The images numpy array is split into batches with
+ batch size batch_size. A reasonable batch size
+ depends on the hardware.
+ -- dims : Dimensionality of features returned by Inception
+ -- device : Device to run calculations
+ Returns:
+ -- mu : The mean over samples of the activations of the pool_3 layer of
+ the inception model.
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
+ the inception model.
+ """
+ act = get_activations(images, model, batch_size, dims, device)
+ mu = np.mean(act, axis=0)
+ sigma = np.cov(act, rowvar=False)
+ return mu, sigma
+
+
+def get_activations(images, model, batch_size=50, dims=2048, device="cpu"):
+ """Calculates the activations of the pool_3 layer for all images.
+ Params:
+ -- images : List of images
+ -- model : Instance of inception model
+ -- batch_size : Batch size of images for the model to process at once.
+ Make sure that the number of samples is a multiple of
+ the batch size, otherwise some samples are ignored. This
+ behavior is retained to match the original FID score
+ implementation.
+ -- dims : Dimensionality of features returned by Inception
+ -- device : Device to run calculations
+ Returns:
+ -- A numpy array of dimension (num images, dims) that contains the
+ activations of the given tensor when feeding inception with the
+ query tensor.
+ """
+ model.eval()
+
+ pred_arr = np.empty((len(images), dims))
+
+ start_idx = 0
+ nbs = len(images) // batch_size + 1
+
+ for b in range(nbs):
+ batch = images[b * batch_size : (b + 1) * batch_size].to(device)
+ if not batch.nelement():
+ continue
+
+ with torch.no_grad():
+ pred = model(batch)[0]
+
+ # If model output is not scalar, apply global spatial average pooling.
+ # This happens if you choose a dimensionality not equal 2048.
+ if pred.size(2) != 1 or pred.size(3) != 1:
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
+
+ pred = pred.squeeze(3).squeeze(2).cpu().numpy()
+
+ pred_arr[start_idx : start_idx + pred.shape[0]] = pred
+
+ start_idx = start_idx + pred.shape[0]
+
+ return pred_arr
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """Numpy implementation of the Frechet Distance.
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+ Stable version by Dougal J. Sutherland.
+ Params:
+ -- mu1 : Numpy array containing the activations of a layer of the
+ inception net (like returned by the function 'get_predictions')
+ for generated samples.
+    -- mu2   : The sample mean over activations, precalculated on a
+               representative data set.
+    -- sigma1: The covariance matrix over activations for generated samples.
+    -- sigma2: The covariance matrix over activations, precalculated on a
+               representative data set.
+ Returns:
+ -- : The Frechet Distance.
+ """
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert (
+ mu1.shape == mu2.shape
+ ), "Training and test mean vectors have different lengths"
+ assert (
+ sigma1.shape == sigma2.shape
+ ), "Training and test covariances have different dimensions"
+
+ diff = mu1 - mu2
+
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = (
+ "fid calculation produces singular product; "
+ "adding %s to diagonal of cov estimates"
+ ) % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError("Imaginary component {}".format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
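+
+
+# A minimal end-to-end sketch of the FID machinery above on random tensors.
+# It uses the 64-d first-block features so that the number of samples (128)
+# exceeds the feature dimension and the covariances stay well conditioned;
+# real evaluations use dims=2048 and many more images, so treat the returned
+# number as a smoke test, not a meaningful score.
+def _fid_smoke_test(device="cpu"):
+    dims = 64
+    model = InceptionV3([InceptionV3.BLOCK_INDEX_BY_DIM[dims]]).to(device)
+    real = torch.rand(128, 3, 64, 64)
+    fake = torch.rand(128, 3, 64, 64)
+    real_m, real_s = calculate_activation_statistics(
+        real, model, batch_size=32, dims=dims, device=device
+    )
+    fake_m, fake_s = calculate_activation_statistics(
+        fake, model, batch_size=32, dims=dims, device=device
+    )
+    return calculate_frechet_distance(real_m, real_s, fake_m, fake_s)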
diff --git a/climategan/fire.py b/climategan/fire.py
new file mode 100644
index 0000000000000000000000000000000000000000..0181e47bc8848627244abeca689f8dc5d1132d74
--- /dev/null
+++ b/climategan/fire.py
@@ -0,0 +1,133 @@
+import torch
+import torch.nn.functional as F
+import random
+import kornia
+from torchvision.transforms.functional import adjust_brightness, adjust_contrast
+
+from climategan.tutils import normalize, retrieve_sky_mask
+
+try:
+ from kornia.filters import filter2d
+except ImportError:
+ from kornia.filters import filter2D as filter2d
+
+
+def increase_sky_mask(mask, p_w=0, p_h=0):
+ """
+    Increases the sky mask in width and height by a given percentage
+    (purpose: so that no blue-sky artifacts remain behind when the Gaussian blur is applied)
+    Args:
+        mask (torch.Tensor): Sky mask of shape (B, C, H, W)
+ p_w (float): Percentage of mask width by which to increase
+ the width of the sky region
+ p_h (float): Percentage of mask height by which to increase
+ the height of the sky region
+ Returns:
+ torch.Tensor: Sky mask increased given p_w and p_h
+ """
+
+ if p_h <= 0 and p_w <= 0:
+ return mask
+
+ n_lines = int(p_h * mask.shape[-2])
+ n_cols = int(p_w * mask.shape[-1])
+
+ temp_mask = mask.clone().detach()
+ for i in range(1, n_cols):
+ temp_mask[:, :, :, i::] += mask[:, :, :, 0:-i]
+ temp_mask[:, :, :, 0:-i] += mask[:, :, :, i::]
+
+ new_mask = temp_mask.clone().detach()
+ for i in range(1, n_lines):
+ new_mask[:, :, i::, :] += temp_mask[:, :, 0:-i, :]
+ new_mask[:, :, 0:-i, :] += temp_mask[:, :, i::, :]
+
+ new_mask[new_mask >= 1] = 1
+
+ return new_mask
+
+
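+# A toy sketch of the widening behaviour above, assuming a (B, C, H, W) binary
+# mask; the shapes and percentages are illustrative only.
+def _increase_sky_mask_example():
+    mask = torch.zeros(1, 1, 10, 10)
+    mask[:, :, 0:2, 4:6] = 1.0  # small sky patch at the top
+    widened = increase_sky_mask(mask, p_w=0.2, p_h=0.2)
+    return mask.sum().item(), widened.sum().item()  # 4.0 -> a larger area
+
+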
+def paste_filter(x, filter_, mask):
+ """
+ Pastes a filter over an image given a mask
+ Where the mask is 1, the filter is copied as is.
+ Where the mask is 0, the current value is preserved.
+ Intermediate values will mix the two images together.
+ Args:
+ x (torch.Tensor): Input tensor, range must be [0, 255]
+        filter_ (torch.Tensor): Filter, range must be [0, 255]
+ mask (torch.Tensor): Mask, range must be [0, 1]
+ Returns:
+ torch.Tensor: New tensor with filter pasted on it
+ """
+ assert len(x.shape) == len(filter_.shape) == len(mask.shape)
+ x = filter_ * mask + x * (1 - mask)
+ return x
+
+
+def add_fire(x, seg_preds, fire_opts):
+ """
+ Transforms input tensor given wildfires event
+ Args:
+ x (torch.Tensor): Input tensor
+ seg_preds (torch.Tensor): Semantic segmentation predictions for input tensor
+        fire_opts (Dict): Wildfire event options, e.g. kernel_size and
+            kernel_sigma of the Gaussian blur that smooths the transition
+            between sky and foreground, and crop_bottom_sky_mask
+ Returns:
+ torch.Tensor: Wildfire version of input tensor
+ """
+ wildfire_tens = normalize(x, 0, 255)
+
+ # Warm the image
+ wildfire_tens[:, 2, :, :] -= 20
+ wildfire_tens[:, 1, :, :] -= 10
+ wildfire_tens[:, 0, :, :] += 40
+ wildfire_tens.clamp_(0, 255)
+ wildfire_tens = wildfire_tens.to(torch.uint8)
+
+ # Darken the picture and increase contrast
+ wildfire_tens = adjust_contrast(wildfire_tens, contrast_factor=1.5)
+ wildfire_tens = adjust_brightness(wildfire_tens, brightness_factor=0.73)
+
+ sky_mask = retrieve_sky_mask(seg_preds).unsqueeze(1)
+
+ if fire_opts.get("crop_bottom_sky_mask"):
+ i = 2 * sky_mask.shape[-2] // 3
+ sky_mask[..., i:, :] = 0
+
+ sky_mask = F.interpolate(
+ sky_mask.to(torch.float),
+ (wildfire_tens.shape[-2], wildfire_tens.shape[-1]),
+ )
+ sky_mask = increase_sky_mask(sky_mask, 0.18, 0.18)
+
+ kernel_size = (fire_opts.get("kernel_size", 301), fire_opts.get("kernel_size", 301))
+ sigma = (fire_opts.get("kernel_sigma", 150.5), fire_opts.get("kernel_sigma", 150.5))
+ border_type = "reflect"
+ kernel = torch.unsqueeze(
+ kornia.filters.kernels.get_gaussian_kernel2d(kernel_size, sigma), dim=0
+ ).to(x.device)
+ sky_mask = filter2d(sky_mask, kernel, border_type)
+
+ filter_ = torch.ones(wildfire_tens.shape, device=x.device)
+ filter_[:, 0, :, :] = 255
+ filter_[:, 1, :, :] = random.randint(100, 150)
+ filter_[:, 2, :, :] = 0
+
+ wildfire_tens = paste_tensor(wildfire_tens, filter_, sky_mask, 200)
+
+ wildfire_tens = adjust_brightness(wildfire_tens.to(torch.uint8), 0.8)
+ wildfire_tens = wildfire_tens.to(torch.float)
+
+ # dummy pixels to fool scaling and preserve range
+ wildfire_tens[:, :, 0, 0] = 255.0
+ wildfire_tens[:, :, -1, -1] = 0.0
+
+ return wildfire_tens
+
+
+def paste_tensor(source, filter_, mask, transparency):
+ mask = transparency / 255.0 * mask
+ new = mask * filter_ + (1.0 - mask) * source
+ return new
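+
+
+# A tiny sketch of the compositing used above: with transparency=200 the filter
+# covers roughly 200/255 of the masked pixels. Shapes and values are
+# illustrative only.
+def _paste_tensor_example():
+    source = torch.zeros(1, 3, 4, 4)
+    filter_ = torch.full((1, 3, 4, 4), 255.0)
+    mask = torch.zeros(1, 1, 4, 4)
+    mask[..., :2, :] = 1.0  # top half receives the filter
+    return paste_tensor(source, filter_, mask, transparency=200)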
diff --git a/climategan/generator.py b/climategan/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ed7d42fee6e15bcc4bbd657f897843e82bb530b
--- /dev/null
+++ b/climategan/generator.py
@@ -0,0 +1,415 @@
+"""Complete Generator architecture:
+ * OmniGenerator
+ * Encoder
+ * Decoders
+"""
+from pathlib import Path
+import traceback
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import yaml
+from addict import Dict
+from torch import softmax
+
+import climategan.strings as strings
+from climategan.deeplab import create_encoder, create_segmentation_decoder
+from climategan.depth import create_depth_decoder
+from climategan.masker import create_mask_decoder
+from climategan.painter import create_painter
+from climategan.tutils import init_weights, mix_noise, normalize
+
+
+def create_generator(opts, device="cpu", latent_shape=None, no_init=False, verbose=0):
+ G = OmniGenerator(opts, latent_shape, verbose, no_init)
+ if no_init:
+ print("Sending to", device)
+ return G.to(device)
+
+ for model in G.decoders:
+ net = G.decoders[model]
+ if model == "s":
+ continue
+ if isinstance(net, nn.ModuleDict):
+ for domain, domain_model in net.items():
+ init_weights(
+ net[domain_model],
+ init_type=opts.gen[model].init_type,
+ init_gain=opts.gen[model].init_gain,
+ verbose=verbose,
+ caller=f"create_generator decoder {model} {domain}",
+ )
+ else:
+ init_weights(
+ G.decoders[model],
+ init_type=opts.gen[model].init_type,
+ init_gain=opts.gen[model].init_gain,
+ verbose=verbose,
+ caller=f"create_generator decoder {model}",
+ )
+ if G.encoder is not None and opts.gen.encoder.architecture == "base":
+ init_weights(
+ G.encoder,
+ init_type=opts.gen.encoder.init_type,
+ init_gain=opts.gen.encoder.init_gain,
+ verbose=verbose,
+ caller="create_generator encoder",
+ )
+
+ print("Sending to", device)
+ return G.to(device)
+
+
+class OmniGenerator(nn.Module):
+ def __init__(self, opts, latent_shape=None, verbose=0, no_init=False):
+ """Creates the generator. All decoders listed in opts.gen will be added
+ to the Generator.decoders ModuleDict if opts.gen.DecoderInitial is not True.
+ Then can be accessed as G.decoders.T or G.decoders["T"] for instance,
+ for the image Translation decoder
+
+ Args:
+ opts (addict.Dict): configuration dict
+ """
+ super().__init__()
+ self.opts = opts
+ self.verbose = verbose
+ self.encoder = None
+ if any(t in opts.tasks for t in "msd"):
+ self.encoder = create_encoder(opts, no_init, verbose)
+
+ self.decoders = {}
+ self.painter = nn.Module()
+
+ if "d" in opts.tasks:
+ self.decoders["d"] = create_depth_decoder(opts, no_init, verbose)
+
+ if self.verbose > 0:
+ print(f" - Add {self.decoders['d'].__class__.__name__}")
+
+ if "s" in opts.tasks:
+ self.decoders["s"] = create_segmentation_decoder(opts, no_init, verbose)
+
+ if "m" in opts.tasks:
+ self.decoders["m"] = create_mask_decoder(opts, no_init, verbose)
+
+ self.decoders = nn.ModuleDict(self.decoders)
+
+ if "p" in self.opts.tasks:
+ self.painter = create_painter(opts, no_init, verbose)
+ else:
+ if self.verbose > 0:
+ print(" - Add Empty Painter")
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
+
+ def __str__(self):
+ return strings.generator(self)
+
+ def encode(self, x):
+ """
+ Forward x through the encoder
+
+ Args:
+ x (torch.Tensor): B3HW input tensor
+
+ Returns:
+ list: High and Low level features from the encoder
+ """
+ assert self.encoder is not None
+ return self.encoder.forward(x)
+
+ def decode(self, x=None, z=None, return_z=False, return_z_depth=False):
+ """
+        Computes the predictions of all available decoders from either x or z.
+        If using SPADE for the masker with 15 conditioning channels, x *must*
+        be provided, whether or not z is.
+
+ Args:
+ x (torch.Tensor, optional): Input tensor (B3HW). Defaults to None.
+ z (list, optional): List of high and low-level features as BCHW.
+ Defaults to None.
+ return_z (bool, optional): whether or not to return z in the dict.
+ Defaults to False.
+ return_z_depth (bool, optional): whether or not to return z_depth
+ in the dict. Defaults to False.
+
+ Raises:
+ ValueError: If using spade for the masker with 15 channels but x is None
+
+ Returns:
+ dict: {task: prediction_tensor} (may include z and z_depth
+ depending on args)
+ """
+
+ assert x is not None or z is not None
+        if self.opts.gen.m.use_spade and self.opts.gen.m.spade.cond_nc == 15:
+ if x is None:
+ raise ValueError(
+ "When using spade for the Masker with 15 channels,"
+ + " x MUST be provided"
+ )
+
+ z_depth = cond = d = s = None
+ out = {}
+
+ if z is None:
+ z = self.encode(x)
+
+ if return_z:
+ out["z"] = z
+
+ if "d" in self.decoders:
+ d, z_depth = self.decoders["d"](z)
+ out["d"] = d
+
+ if return_z_depth:
+ out["z_depth"] = z_depth
+
+ if "s" in self.decoders:
+ s = self.decoders["s"](z, z_depth)
+ out["s"] = s
+
+ if "m" in self.decoders:
+ if s is not None and d is not None:
+ cond = self.make_m_cond(d, s, x)
+ m = self.mask(z=z, cond=cond)
+ out["m"] = m
+
+ return out
+
+ def sample_painter_z(self, batch_size, device, force_half=False):
+ if self.opts.gen.p.no_z:
+ return None
+
+ z = torch.empty(
+ batch_size,
+ self.opts.gen.p.latent_dim,
+ self.painter.z_h,
+ self.painter.z_w,
+ device=device,
+ ).normal_(mean=0, std=1.0)
+
+ if force_half:
+ z = z.half()
+
+ return z
+
+ def make_m_cond(self, d, s, x=None):
+ """
+ Create the masker's conditioning input when using spade from the
+ d and s predictions and from the input x when cond_nc == 15.
+
+        d and s are assumed to have the same spatial resolution.
+        If cond_nc == 15 then x is interpolated to match that resolution.
+
+ Args:
+ d (torch.Tensor): Raw depth prediction (B1HW)
+ s (torch.Tensor): Raw segmentation prediction (BCHW)
+            x (torch.Tensor, optional): Input tensor (B3HW). Mandatory
+ when opts.gen.m.spade.cond_nc == 15
+
+ Raises:
+ ValueError: opts.gen.m.spade.cond_nc == 15 but x is None
+
+ Returns:
+ torch.Tensor: B x cond_nc x H x W conditioning tensor.
+ """
+ if self.opts.gen.m.spade.detach:
+ d = d.detach()
+ s = s.detach()
+ cats = [normalize(d), softmax(s, dim=1)]
+ if self.opts.gen.m.spade.cond_nc == 15:
+ if x is None:
+ raise ValueError(
+ "When using spade for the Masker with 15 channels,"
+ + " x MUST be provided"
+ )
+ cats += [
+ F.interpolate(x, s.shape[-2:], mode="bilinear", align_corners=True)
+ ]
+
+ return torch.cat(cats, dim=1)
+
+ def mask(self, x=None, z=None, cond=None, z_depth=None, sigmoid=True):
+ """
+ Create a mask from either an input x or a latent vector z.
+ Optionally if the Masker has a spade architecture the conditioning tensor
+ may be provided (cond). Default behavior applies an element-wise
+ sigmoid, but can be deactivated (sigmoid=False).
+
+ At least one of x or z must be provided (i.e. not None).
+ If the Masker has a spade architecture and cond_nc == 15 then x cannot
+ be None.
+
+ Args:
+ x (torch.Tensor, optional): Input tensor B3HW. Defaults to None.
+ z (list, optional): High and Low level features of the encoder.
+ Will be computed if None. Defaults to None.
+            cond (torch.Tensor, optional): SPADE conditioning tensor, as
+                returned by make_m_cond. Computed from the "d" and "s"
+                decoders if None and the Masker uses SPADE. Defaults to None.
+            z_depth (torch.Tensor, optional): Depth features for DADA, computed
+                from the "d" decoder if None and the Masker uses DADA.
+                Defaults to None.
+            sigmoid (bool, optional): Whether to apply a sigmoid to the
+                Masker's logits. Defaults to True.
+
+ Returns:
+ torch.Tensor: B1HW mask tensor
+ """
+ assert x is not None or z is not None
+ if z is None:
+ z = self.encode(x)
+
+ if cond is None and self.opts.gen.m.use_spade:
+ assert "s" in self.opts.tasks and "d" in self.opts.tasks
+ with torch.no_grad():
+ d_pred, z_d = self.decoders["d"](z)
+ s_pred = self.decoders["s"](z, z_d)
+ cond = self.make_m_cond(d_pred, s_pred, x)
+ if z_depth is None and self.opts.gen.m.use_dada:
+ assert "d" in self.opts.tasks
+ with torch.no_grad():
+ _, z_depth = self.decoders["d"](z)
+
+ if cond is not None:
+ device = z[0].device if isinstance(z, (tuple, list)) else z.device
+ cond = cond.to(device)
+
+ logits = self.decoders["m"](z, cond, z_depth)
+
+ if not sigmoid:
+ return logits
+
+ return torch.sigmoid(logits)
+
+ def paint(self, m, x, no_paste=False):
+ """
+ Paints given a mask and an image
+ calls painter(z, x * (1.0 - m))
+ Mask has 1s where water should be painted
+
+ Args:
+ m (torch.Tensor): Mask
+ x (torch.Tensor): Image to paint
+
+ Returns:
+ torch.Tensor: painted image
+ """
+ z_paint = self.sample_painter_z(x.shape[0], x.device)
+ m = m.to(x.dtype)
+ fake = self.painter(z_paint, x * (1.0 - m))
+ if self.opts.gen.p.paste_original_content and not no_paste:
+ return x * (1.0 - m) + fake * m
+ return fake
+
+ def paint_cloudy(self, m, x, s, sky_idx=9, res=(8, 8), weight=0.8):
+ """
+ Paints x with water in m through an intermediary cloudy image
+ where the sky has been replaced with perlin noise to imitate clouds.
+
+ The intermediary cloudy image is only used to control the painter's
+ painting mode, probing it with a cloudy input.
+
+ Args:
+ m (torch.Tensor): water mask
+ x (torch.Tensor): input tensor
+ s (torch.Tensor): segmentation prediction (BCHW)
+ sky_idx (int, optional): Index of the sky class along s's C dimension.
+ Defaults to 9.
+ res (tuple, optional): Perlin noise spatial resolution. Defaults to (8, 8).
+ weight (float, optional): Intermediate image's cloud proportion
+ (w * cloud + (1-w) * original_sky). Defaults to 0.8.
+
+ Returns:
+ torch.Tensor: painted image with original content pasted.
+ """
+ sky_mask = (
+ torch.argmax(
+ F.interpolate(s, x.shape[-2:], mode="bilinear"), dim=1, keepdim=True
+ )
+ == sky_idx
+ ).to(x.dtype)
+ noised_x = mix_noise(x, sky_mask, res=res, weight=weight).to(x.dtype)
+ fake = self.paint(m, noised_x, no_paste=True)
+ return x * (1.0 - m) + fake * m
+
+ def depth(self, x=None, z=None, return_z_depth=False):
+ """
+ Compute the depth head's output
+
+ Args:
+ x (torch.Tensor, optional): Input B3HW tensor. Defaults to None.
+ z (list, optional): High and Low level features of the encoder.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: B1HW tensor of depth predictions
+ """
+ assert x is not None or z is not None
+ assert not (x is not None and z is not None)
+ if z is None:
+ z = self.encode(x)
+ depth, z_depth = self.decoders["d"](z)
+
+ if depth.shape[1] > 1:
+ depth = torch.argmax(depth, dim=1)
+ depth = depth / depth.max()
+
+ if return_z_depth:
+ return depth, z_depth
+
+ return depth
+
+ def load_val_painter(self):
+ """
+ Loads a validation painter if available in opts.val.val_painter
+
+ Returns:
+ bool: operation success status
+ """
+ try:
+ # key exists in opts
+ assert self.opts.val.val_painter
+
+ # path exists
+ ckpt_path = Path(self.opts.val.val_painter).resolve()
+ assert ckpt_path.exists()
+
+ # path is a checkpoint path
+ assert ckpt_path.is_file()
+
+ # opts are available in that path
+ opts_path = ckpt_path.parent.parent / "opts.yaml"
+ assert opts_path.exists()
+
+ # load opts
+ with opts_path.open("r") as f:
+ val_painter_opts = Dict(yaml.safe_load(f))
+
+ # load checkpoint
+ state_dict = torch.load(ckpt_path, map_location=self.device)
+
+ # create dummy painter from loaded opts
+ painter = create_painter(val_painter_opts)
+
+ # load state-dict in the dummy painter
+ painter.load_state_dict(
+ {k.replace("painter.", ""): v for k, v in state_dict["G"].items()}
+ )
+
+ # send to current device in evaluation mode
+ device = next(self.parameters()).device
+ self.painter = painter.eval().to(device)
+
+ # disable gradients
+ for p in self.painter.parameters():
+ p.requires_grad = False
+
+ # success
+ print(" - Loaded validation-only painter")
+ return True
+
+ except Exception as e:
+ # something happened, aborting gracefully
+ print(traceback.format_exc())
+ print(e)
+ print(">>> WARNING: error (^) in load_val_painter, aborting.")
+ return False
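+
+
+# A minimal inference sketch, assuming `G` is an already-built OmniGenerator
+# (for instance trainer.G from a resumed checkpoint) whose tasks include "m"
+# and "p", and `x` is a 1x3xHxW tensor in [-1, 1] on the same device as G.
+def _mask_and_paint(G, x):
+    with torch.no_grad():
+        m = G.mask(x=x)  # soft mask in [0, 1], shape 1x1xHxW
+        m_bin = (m > 0.5).to(m.dtype)  # optional binarization
+        # painted image (original content pasted back if configured)
+        flooded = G.paint(m_bin, x)
+    return flooded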
diff --git a/climategan/logger.py b/climategan/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..a20023e453fbfb8cbb6c351d34ea348ac25117e0
--- /dev/null
+++ b/climategan/logger.py
@@ -0,0 +1,445 @@
+from pathlib import Path
+
+import numpy as np
+import torch
+import torchvision.utils as vutils
+from addict import Dict
+from PIL import Image
+from torch.nn.functional import interpolate, sigmoid
+
+from climategan.data import decode_segmap_merged_labels
+from climategan.tutils import (
+ all_texts_to_tensors,
+ decode_bucketed_depth,
+ normalize_tensor,
+ write_architecture,
+)
+from climategan.utils import flatten_opts
+
+
+class Logger:
+ def __init__(self, trainer):
+ self.losses = Dict()
+ self.time = Dict()
+ self.trainer = trainer
+ self.global_step = 0
+ self.epoch = 0
+
+ def log_comet_images(self, mode, domain, minimal=False, all_only=False):
+ trainer = self.trainer
+ save_images = {}
+ all_images = []
+ n_all_ims = None
+ all_legends = ["Input"]
+ task_legends = {}
+
+ if domain not in trainer.display_images[mode]:
+ return
+
+ # --------------------
+ # ----- Masker -----
+ # --------------------
+ n_ims = len(trainer.display_images[mode][domain])
+ print(" " * 60, end="\r")
+ if domain != "rf":
+ for j, display_dict in enumerate(trainer.display_images[mode][domain]):
+
+ print(f"Inferring sample {mode} {domain} {j+1}/{n_ims}", end="\r")
+
+ x = display_dict["data"]["x"].unsqueeze(0).to(trainer.device)
+ z = trainer.G.encode(x)
+
+ s_pred = decoded_s_pred = d_pred = z_depth = None
+ for k, task in enumerate(["d", "s", "m"]):
+
+ if (
+ task not in display_dict["data"]
+ or task not in trainer.opts.tasks
+ ):
+ continue
+
+ task_legend = ["Input"]
+ target = display_dict["data"][task]
+ target = target.unsqueeze(0).to(trainer.device)
+ task_saves = []
+
+ if task not in save_images:
+ save_images[task] = []
+
+ prediction = None
+ if task == "m":
+ cond = None
+ if s_pred is not None and d_pred is not None:
+ cond = trainer.G.make_m_cond(d_pred, s_pred, x)
+
+ prediction = trainer.G.decoders[task](z, cond, z_depth)
+ elif task == "d":
+ prediction, z_depth = trainer.G.decoders[task](z)
+ elif task == "s":
+ prediction = trainer.G.decoders[task](z, z_depth)
+
+ if task == "s":
+ # Log fire
+ wildfire_tens = trainer.compute_fire(x, prediction)
+ task_saves.append(wildfire_tens)
+ task_legend.append("Wildfire")
+ # Log seg output
+ s_pred = prediction.clone()
+ target = (
+ decode_segmap_merged_labels(target, domain, True)
+ .float()
+ .to(trainer.device)
+ )
+ prediction = (
+ decode_segmap_merged_labels(prediction, domain, False)
+ .float()
+ .to(trainer.device)
+ )
+ decoded_s_pred = prediction
+ task_saves.append(target)
+ task_legend.append("Target Segmentation")
+
+ elif task == "m":
+ prediction = sigmoid(prediction).repeat(1, 3, 1, 1)
+ task_saves.append(x * (1.0 - prediction))
+ if not minimal:
+ task_saves.append(
+ x * (1.0 - (prediction > 0.1).to(torch.int))
+ )
+ task_saves.append(
+ x * (1.0 - (prediction > 0.5).to(torch.int))
+ )
+
+ task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1)))
+ task_legend.append("Masked input")
+
+ if not minimal:
+ task_legend.append("Masked input (>0.1)")
+ task_legend.append("Masked input (>0.5)")
+
+ task_legend.append("Masked input (target)")
+ # dummy pixels to fool scaling and preserve mask range
+ prediction[:, :, 0, 0] = 1.0
+ prediction[:, :, -1, -1] = 0.0
+
+ elif task == "d":
+ # prediction is a log depth tensor
+ d_pred = prediction
+ target = normalize_tensor(target) * 255
+ if prediction.shape[1] > 1:
+ prediction = decode_bucketed_depth(
+ prediction, self.trainer.opts
+ )
+ smogged = self.trainer.compute_smog(
+ x, d=prediction, s=decoded_s_pred, use_sky_seg=False
+ )
+ prediction = normalize_tensor(prediction)
+ prediction = prediction.repeat(1, 3, 1, 1)
+ task_saves.append(smogged)
+ task_legend.append("Smogged")
+ task_saves.append(target.repeat(1, 3, 1, 1))
+ task_legend.append("Depth target")
+
+ task_saves.append(prediction)
+ task_legend.append(f"Predicted {task}")
+
+ save_images[task].append(x.cpu().detach())
+ if k == 0:
+ all_images.append(save_images[task][-1])
+
+ task_legends[task] = task_legend
+ if j == 0:
+ all_legends += task_legend[1:]
+
+ for im in task_saves:
+ save_images[task].append(im.cpu().detach())
+ all_images.append(save_images[task][-1])
+
+ if j == 0:
+ n_all_ims = len(all_images)
+
+ if not all_only:
+ for task in save_images.keys():
+ # Write images:
+ self.upload_images(
+ image_outputs=save_images[task],
+ mode=mode,
+ domain=domain,
+ task=task,
+ im_per_row=trainer.opts.comet.im_per_row.get(task, 4),
+ rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
+ legends=task_legends[task],
+ )
+
+ if len(save_images) > 1:
+ self.upload_images(
+ image_outputs=all_images,
+ mode=mode,
+ domain=domain,
+ task="all",
+ im_per_row=n_all_ims,
+ rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
+ legends=all_legends,
+ )
+ # ---------------------
+ # ----- Painter -----
+ # ---------------------
+ else:
+ # in the rf domain display_size may be different from fid.n_images
+ limit = trainer.opts.comet.display_size
+ image_outputs = []
+ legends = []
+ for im_set in trainer.display_images[mode][domain][:limit]:
+ x = im_set["data"]["x"].unsqueeze(0).to(trainer.device)
+ m = im_set["data"]["m"].unsqueeze(0).to(trainer.device)
+
+ prediction = trainer.G.paint(m, x)
+
+ image_outputs.append(x * (1.0 - m))
+ image_outputs.append(prediction)
+ image_outputs.append(x)
+ image_outputs.append(prediction * m)
+ if not legends:
+ legends.append("Masked Input")
+ legends.append("Painted Input")
+ legends.append("Input")
+ legends.append("Isolated Water")
+ # Write images
+ self.upload_images(
+ image_outputs=image_outputs,
+ mode=mode,
+ domain=domain,
+ task="painter",
+ im_per_row=trainer.opts.comet.im_per_row.get("p", 4),
+ rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
+ legends=legends,
+ )
+
+ return 0
+
+ def log_losses(self, model_to_update="G", mode="train"):
+ """Logs metrics on comet.ml
+
+ Args:
+ model_to_update (str, optional): One of "G", "D". Defaults to "G".
+ """
+ trainer = self.trainer
+ loss_names = {"G": "gen", "D": "disc"}
+
+ if trainer.opts.train.log_level < 1:
+ return
+
+ if trainer.exp is None:
+ return
+
+ assert model_to_update in {
+ "G",
+ "D",
+ }, "unknown model to log losses {}".format(model_to_update)
+
+ loss_to_update = self.losses[loss_names[model_to_update]]
+
+ losses = loss_to_update.copy()
+
+ if trainer.opts.train.log_level == 1:
+ # Only log aggregated losses: delete other keys in losses
+ for k in loss_to_update:
+ if k not in {"masker", "total_loss", "painter"}:
+ del losses[k]
+            # convert losses into a single-level dictionary
+
+ losses = flatten_opts(losses)
+ trainer.exp.log_metrics(
+ losses, prefix=f"{model_to_update}_{mode}", step=self.global_step
+ )
+
+ def log_learning_rates(self):
+ if self.trainer.exp is None:
+ return
+ lrs = {}
+ trainer = self.trainer
+ if trainer.g_scheduler is not None:
+ for name, lr in zip(
+ trainer.lr_names["G"], trainer.g_scheduler.get_last_lr()
+ ):
+ lrs[f"lr_G_{name}"] = lr
+ if trainer.d_scheduler is not None:
+ for name, lr in zip(
+ trainer.lr_names["D"], trainer.d_scheduler.get_last_lr()
+ ):
+ lrs[f"lr_D_{name}"] = lr
+
+ trainer.exp.log_metrics(lrs, step=self.global_step)
+
+ def log_step_time(self, time):
+ """Logs step-time on comet.ml
+
+ Args:
+            time (float): current time in seconds
+ """
+ if self.trainer.exp:
+ self.trainer.exp.log_metric(
+ "step-time", time - self.time.step_start, step=self.global_step
+ )
+
+ def log_epoch_time(self, time):
+ """Logs step-time on comet.ml
+
+ Args:
+ step_time (float): step-time in seconds
+ """
+ if self.trainer.exp:
+ self.trainer.exp.log_metric(
+ "epoch-time", time - self.time.epoch_start, step=self.global_step
+ )
+
+ def log_comet_combined_images(self, mode, domain):
+
+ trainer = self.trainer
+ image_outputs = []
+ legends = []
+ im_per_row = 0
+ for i, im_set in enumerate(trainer.display_images[mode][domain]):
+ x = im_set["data"]["x"].unsqueeze(0).to(trainer.device)
+ # m = im_set["data"]["m"].unsqueeze(0).to(trainer.device)
+
+ m = trainer.G.mask(x=x)
+ m_bin = (m > 0.5).to(m.dtype)
+ prediction = trainer.G.paint(m, x)
+ prediction_bin = trainer.G.paint(m_bin, x)
+
+ image_outputs.append(x)
+ legends.append("Input")
+ image_outputs.append(x * (1.0 - m))
+ legends.append("Soft Masked Input")
+ image_outputs.append(prediction)
+ legends.append("Painted")
+ image_outputs.append(prediction * m)
+ legends.append("Soft Masked Painted")
+ image_outputs.append(x * (1.0 - m_bin))
+ legends.append("Binary (0.5) Masked Input")
+ image_outputs.append(prediction_bin)
+ legends.append("Binary (0.5) Painted")
+ image_outputs.append(prediction_bin * m_bin)
+ legends.append("Binary (0.5) Masked Painted")
+
+ if i == 0:
+ im_per_row = len(image_outputs)
+ # Upload images
+ self.upload_images(
+ image_outputs=image_outputs,
+ mode=mode,
+ domain=domain,
+ task="combined",
+ im_per_row=im_per_row or 7,
+ rows_per_log=trainer.opts.comet.get("rows_per_log", 5),
+ legends=legends,
+ )
+
+ return 0
+
+ def upload_images(
+ self,
+ image_outputs,
+ mode,
+ domain,
+ task,
+ im_per_row=3,
+ rows_per_log=5,
+ legends=[],
+ ):
+ """
+        Upload grids of output images to comet.ml
+
+ Args:
+ image_outputs (list(torch.Tensor)): all the images to log
+ mode (str): train or val
+ domain (str): current domain
+ task (str): current task
+            im_per_row (int, optional): Number of images to be displayed per row.
+                Typically, for a given task: 3 because [input, prediction, target].
+                Defaults to 3.
+            rows_per_log (int, optional): Number of rows (=samples) per uploaded image.
+                Defaults to 5.
+            legends (list(str), optional): Column headers drawn above the grid,
+                one per image in a row. Defaults to [].
+ """
+ trainer = self.trainer
+ if trainer.exp is None:
+ return
+ curr_iter = self.global_step
+ nb_per_log = im_per_row * rows_per_log
+ n_logs = len(image_outputs) // nb_per_log + 1
+
+ header = None
+ if len(legends) == im_per_row and all(isinstance(t, str) for t in legends):
+ header_width = max(im.shape[-1] for im in image_outputs)
+ headers = all_texts_to_tensors(legends, width=header_width)
+ header = torch.cat(headers, dim=-1)
+
+ for logidx in range(n_logs):
+ print(" " * 100, end="\r", flush=True)
+ print(
+ "Uploading images for {} {} {} {}/{}".format(
+ mode, domain, task, logidx + 1, n_logs
+ ),
+ end="...",
+ flush=True,
+ )
+ ims = image_outputs[logidx * nb_per_log : (logidx + 1) * nb_per_log]
+ if not ims:
+ continue
+
+ ims = self.upsample(ims)
+ ims = torch.stack([im.squeeze() for im in ims]).squeeze()
+ image_grid = vutils.make_grid(
+ ims, nrow=im_per_row, normalize=True, scale_each=True, padding=0
+ )
+
+ if header is not None:
+ image_grid = torch.cat(
+ [header.to(image_grid.device), image_grid], dim=1
+ )
+
+ image_grid = image_grid.permute(1, 2, 0).cpu().numpy()
+ trainer.exp.log_image(
+ Image.fromarray((image_grid * 255).astype(np.uint8)),
+ name=f"{mode}_{domain}_{task}_{str(curr_iter)}_#{logidx}",
+ step=curr_iter,
+ )
+
+ def upsample(self, ims):
+ h = max(im.shape[-2] for im in ims)
+ w = max(im.shape[-1] for im in ims)
+ new_ims = []
+ for im in ims:
+ im = interpolate(im, (h, w), mode="bilinear")
+ new_ims.append(im)
+ return new_ims
+
+ def padd(self, ims):
+ h = max(im.shape[-2] for im in ims)
+ w = max(im.shape[-1] for im in ims)
+ new_ims = []
+ for im in ims:
+ ih = im.shape[-2]
+ iw = im.shape[-1]
+ if ih != h or iw != w:
+ padded = torch.zeros(im.shape[-3], h, w)
+ padded[
+ :, (h - ih) // 2 : (h + ih) // 2, (w - iw) // 2 : (w + iw) // 2
+ ] = im
+ new_ims.append(padded)
+ else:
+ new_ims.append(im)
+
+ return new_ims
+
+ def log_architecture(self):
+ write_architecture(self.trainer)
+
+ if self.trainer.exp is None:
+ return
+
+ for f in Path(self.trainer.opts.output_path).glob("archi*.txt"):
+ self.trainer.exp.log_asset(str(f), overwrite=True)
diff --git a/climategan/losses.py b/climategan/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..f10a5d26c73795bad02837f546b96c76b24e7564
--- /dev/null
+++ b/climategan/losses.py
@@ -0,0 +1,620 @@
+"""Define all losses. When possible, as inheriting from nn.Module
+To send predictions to target.device
+"""
+from random import random as rand
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision import models
+
+
+class GANLoss(nn.Module):
+ def __init__(
+ self,
+ use_lsgan=True,
+ target_real_label=1.0,
+ target_fake_label=0.0,
+ soft_shift=0.0,
+ flip_prob=0.0,
+ verbose=0,
+ ):
+ """Defines the GAN loss which uses either LSGAN or the regular GAN.
+ When LSGAN is used, it is basically same as MSELoss,
+ but it abstracts away the need to create the target label tensor
+        that has the same size as the input, plus:
+
+ * label smoothing: target_real_label=0.75
+ * label flipping: flip_prob > 0.
+
+ source: https://github.com/sangwoomo/instagan/blob
+ /b67e9008fcdd6c41652f8805f0b36bcaa8b632d6/models/networks.py
+
+ Args:
+ use_lsgan (bool, optional): Use MSE or BCE. Defaults to True.
+ target_real_label (float, optional): Value for the real target.
+ Defaults to 1.0.
+ target_fake_label (float, optional): Value for the fake target.
+ Defaults to 0.0.
+ flip_prob (float, optional): Probability of flipping the label
+ (use for real target in Discriminator only). Defaults to 0.0.
+ """
+ super().__init__()
+
+ self.soft_shift = soft_shift
+ self.verbose = verbose
+
+ self.register_buffer("real_label", torch.tensor(target_real_label))
+ self.register_buffer("fake_label", torch.tensor(target_fake_label))
+ if use_lsgan:
+ self.loss = nn.MSELoss()
+ else:
+ self.loss = nn.BCEWithLogitsLoss()
+ self.flip_prob = flip_prob
+
+ def get_target_tensor(self, input, target_is_real):
+ soft_change = torch.FloatTensor(1).uniform_(0, self.soft_shift)
+ if self.verbose > 0:
+ print("GANLoss sampled soft_change:", soft_change.item())
+ if target_is_real:
+ target_tensor = self.real_label - soft_change
+ else:
+ target_tensor = self.fake_label + soft_change
+ return target_tensor.expand_as(input)
+
+ def __call__(self, input, target_is_real, *args, **kwargs):
+ r = rand()
+ if isinstance(input, list):
+ loss = 0
+ for pred_i in input:
+ if isinstance(pred_i, list):
+ pred_i = pred_i[-1]
+ if r < self.flip_prob:
+ target_is_real = not target_is_real
+ target_tensor = self.get_target_tensor(pred_i, target_is_real)
+ loss_tensor = self.loss(pred_i, target_tensor.to(pred_i.device))
+ loss += loss_tensor
+ return loss / len(input)
+ else:
+ if r < self.flip_prob:
+ target_is_real = not target_is_real
+ target_tensor = self.get_target_tensor(input, target_is_real)
+ return self.loss(input, target_tensor.to(input.device))
+
+
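+# A short sketch of the smoothing / flipping options described above, assuming
+# raw (pre-sigmoid) discriminator outputs since use_lsgan=False relies on
+# BCEWithLogitsLoss; the shapes and probabilities are illustrative only.
+def _gan_loss_example():
+    criterion = GANLoss(use_lsgan=False, target_real_label=0.75, flip_prob=0.05)
+    d_out = torch.randn(4, 1, 30, 30)  # patch-discriminator logits
+    loss_real = criterion(d_out, True)  # smoothed target ~0.75, rarely flipped
+    loss_fake = criterion(d_out, False)
+    return loss_real, loss_fake
+
+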
+class FeatMatchLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.criterionFeat = nn.L1Loss()
+
+ def __call__(self, pred_real, pred_fake):
+ # pred_{real, fake} are lists of features
+ num_D = len(pred_fake)
+ GAN_Feat_loss = 0.0
+ for i in range(num_D): # for each discriminator
+ # last output is the final prediction, so we exclude it
+ num_intermediate_outputs = len(pred_fake[i]) - 1
+ for j in range(num_intermediate_outputs): # for each layer output
+ unweighted_loss = self.criterionFeat(
+ pred_fake[i][j], pred_real[i][j].detach()
+ )
+ GAN_Feat_loss += unweighted_loss / num_D
+ return GAN_Feat_loss
+
+
+class CrossEntropy(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.loss = nn.CrossEntropyLoss()
+
+ def __call__(self, logits, target):
+ return self.loss(logits, target.to(logits.device).long())
+
+
+class TravelLoss(nn.Module):
+ def __init__(self, eps=1e-12):
+ super().__init__()
+ self.eps = eps
+
+ def cosine_loss(self, real, fake):
+ norm_real = torch.norm(real, p=2, dim=1)[:, None]
+ norm_fake = torch.norm(fake, p=2, dim=1)[:, None]
+ mat_real = real / norm_real
+ mat_fake = fake / norm_fake
+ mat_real = torch.max(mat_real, self.eps * torch.ones_like(mat_real))
+ mat_fake = torch.max(mat_fake, self.eps * torch.ones_like(mat_fake))
+ # compute only the diagonal of the matrix multiplication
+ return torch.einsum("ij, ji -> i", mat_fake, mat_real).sum()
+
+ def __call__(self, S_real, S_fake):
+ self.v_real = []
+ self.v_fake = []
+ for i in range(len(S_real)):
+ for j in range(i):
+ self.v_real.append((S_real[i] - S_real[j])[None, :])
+ self.v_fake.append((S_fake[i] - S_fake[j])[None, :])
+ self.v_real_t = torch.cat(self.v_real, dim=0)
+ self.v_fake_t = torch.cat(self.v_fake, dim=0)
+ return self.cosine_loss(self.v_real_t, self.v_fake_t)
+
+
+class TVLoss(nn.Module):
+ """Total Variational Regularization: Penalizes differences in
+ neighboring pixel values
+
+ source:
+ https://github.com/jxgu1016/Total_Variation_Loss.pytorch/blob/master/TVLoss.py
+ """
+
+ def __init__(self, tvloss_weight=1):
+ """
+ Args:
+            tvloss_weight (int, optional): lambda, i.e. weight for the loss. Defaults to 1.
+ """
+ super(TVLoss, self).__init__()
+ self.tvloss_weight = tvloss_weight
+
+ def forward(self, x):
+ batch_size = x.size()[0]
+ h_x = x.size()[2]
+ w_x = x.size()[3]
+ count_h = self._tensor_size(x[:, :, 1:, :])
+ count_w = self._tensor_size(x[:, :, :, 1:])
+ h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, : h_x - 1, :]), 2).sum()
+ w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, : w_x - 1]), 2).sum()
+ return self.tvloss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
+
+ def _tensor_size(self, t):
+ return t.size()[1] * t.size()[2] * t.size()[3]
+
+
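+# A quick numeric check of the regularizer above: a constant image has zero
+# total variation while a noisy one does not. Shapes are illustrative only.
+def _tv_loss_example():
+    tv = TVLoss(tvloss_weight=1)
+    flat = torch.ones(1, 3, 8, 8)
+    noisy = torch.rand(1, 3, 8, 8)
+    return tv(flat).item(), tv(noisy).item()  # ~0.0 vs. > 0
+
+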
+class MinentLoss(nn.Module):
+ """
+ Loss for the minimization of the entropy map
+ Source for version 1: https://github.com/valeoai/ADVENT
+
+ Version 2 adds the variance of the entropy map in the computation of the loss
+ """
+
+ def __init__(self, version=1, lambda_var=0.1):
+ super().__init__()
+ self.version = version
+ self.lambda_var = lambda_var
+
+ def __call__(self, pred):
+ assert pred.dim() == 4
+ n, c, h, w = pred.size()
+ entropy_map = -torch.mul(pred, torch.log2(pred + 1e-30)) / np.log2(c)
+ if self.version == 1:
+ return torch.sum(entropy_map) / (n * h * w)
+ else:
+ entropy_map_demean = entropy_map - torch.sum(entropy_map) / (n * h * w)
+ entropy_map_squ = torch.mul(entropy_map_demean, entropy_map_demean)
+ return torch.sum(entropy_map + self.lambda_var * entropy_map_squ) / (
+ n * h * w
+ )
+
+
+class MSELoss(nn.Module):
+ """
+ Creates a criterion that measures the mean squared error
+    (squared L2 norm) between each element in the input x and target y.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.loss = nn.MSELoss()
+
+ def __call__(self, prediction, target):
+ return self.loss(prediction, target.to(prediction.device))
+
+
+class L1Loss(MSELoss):
+ """
+ Creates a criterion that measures the mean absolute error
+ (MAE) between each element in the input x and target y
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.loss = nn.L1Loss()
+
+
+class SIMSELoss(nn.Module):
+ """Scale invariant MSE Loss"""
+
+ def __init__(self):
+ super(SIMSELoss, self).__init__()
+
+ def __call__(self, prediction, target):
+ d = prediction - target
+ diff = torch.mean(d * d)
+ relDiff = torch.mean(d) * torch.mean(d)
+ return diff - relDiff
+
+
+class SIGMLoss(nn.Module):
+ """loss from MiDaS paper
+ MiDaS did not specify how the gradients were computed but we use Sobel
+ filters which approximate the derivative of an image.
+ """
+
+ def __init__(self, gmweight=0.5, scale=4, device="cuda"):
+ super(SIGMLoss, self).__init__()
+ self.gmweight = gmweight
+ self.sobelx = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).to(device)
+ self.sobely = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).to(device)
+ self.scale = scale
+
+ def __call__(self, prediction, target):
+ # get disparities
+ # align both the prediction and the ground truth to have zero
+ # translation and unit scale
+ t_pred = torch.median(prediction)
+ t_targ = torch.median(target)
+ s_pred = torch.mean(torch.abs(prediction - t_pred))
+ s_targ = torch.mean(torch.abs(target - t_targ))
+ pred = (prediction - t_pred) / s_pred
+ targ = (target - t_targ) / s_targ
+
+ R = pred - targ
+
+ # get gradient map with sobel filters
+ batch_size = prediction.size()[0]
+ num_pix = prediction.size()[-1] * prediction.size()[-2]
+ sobelx = (self.sobelx).expand((batch_size, 1, -1, -1))
+ sobely = (self.sobely).expand((batch_size, 1, -1, -1))
+ gmLoss = 0 # gradient matching term
+ for k in range(self.scale):
+ R_ = F.interpolate(R, scale_factor=1 / 2 ** k)
+ Rx = F.conv2d(R_, sobelx, stride=1)
+ Ry = F.conv2d(R_, sobely, stride=1)
+ gmLoss += torch.sum(torch.abs(Rx) + torch.abs(Ry))
+ gmLoss = self.gmweight / num_pix * gmLoss
+ # scale invariant MSE
+ simseLoss = 0.5 / num_pix * torch.sum(torch.abs(R))
+ loss = simseLoss + gmLoss
+ return loss
+
+
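+# A minimal sketch of the scale-and-shift-invariant loss above on random
+# depth maps; device="cpu" avoids the default CUDA Sobel kernels, and the
+# shapes are illustrative only.
+def _sigm_loss_example():
+    criterion = SIGMLoss(gmweight=0.5, scale=4, device="cpu")
+    pred = torch.rand(2, 1, 64, 64)
+    target = torch.rand(2, 1, 64, 64)
+    return criterion(pred, target)
+
+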
+class ContextLoss(nn.Module):
+ """
+ Masked L1 loss on non-water
+ """
+
+ def __call__(self, input, target, mask):
+ return torch.mean(torch.abs(torch.mul((input - target), 1 - mask)))
+
+
+class ReconstructionLoss(nn.Module):
+ """
+ Masked L1 loss on water
+ """
+
+ def __call__(self, input, target, mask):
+ return torch.mean(torch.abs(torch.mul((input - target), mask)))
+
+
+##################################################################################
+# VGG network definition
+##################################################################################
+
+# Source: https://github.com/NVIDIA/pix2pixHD
+class Vgg19(nn.Module):
+ def __init__(self, requires_grad=False):
+ super(Vgg19, self).__init__()
+ vgg_pretrained_features = models.vgg19(pretrained=True).features
+ self.slice1 = nn.Sequential()
+ self.slice2 = nn.Sequential()
+ self.slice3 = nn.Sequential()
+ self.slice4 = nn.Sequential()
+ self.slice5 = nn.Sequential()
+ for x in range(2):
+ self.slice1.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(2, 7):
+ self.slice2.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(7, 12):
+ self.slice3.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(12, 21):
+ self.slice4.add_module(str(x), vgg_pretrained_features[x])
+ for x in range(21, 30):
+ self.slice5.add_module(str(x), vgg_pretrained_features[x])
+ if not requires_grad:
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, X):
+ h_relu1 = self.slice1(X)
+ h_relu2 = self.slice2(h_relu1)
+ h_relu3 = self.slice3(h_relu2)
+ h_relu4 = self.slice4(h_relu3)
+ h_relu5 = self.slice5(h_relu4)
+ out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+ return out
+
+
+# Source: https://github.com/NVIDIA/pix2pixHD
+class VGGLoss(nn.Module):
+ def __init__(self, device):
+ super().__init__()
+ self.vgg = Vgg19().to(device).eval()
+ self.criterion = nn.L1Loss()
+ self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
+
+ def forward(self, x, y):
+ x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+ loss = 0
+ for i in range(len(x_vgg)):
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
+ return loss
+
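+# Minimal usage sketch, assuming `fake` and `real` are N x 3 x H x W image batches
+# on the same device (downloads the torchvision VGG19 weights on first use):
+#   perceptual = VGGLoss(torch.device("cuda"))
+#   loss = perceptual(fake, real)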
+
+def get_losses(opts, verbose, device=None):
+ """Sets the loss functions to be used by G, D and C, as specified
+ in the opts and returns a dictionary of losses:
+
+ losses = {
+ "G": {
+ "gan": {"a": ..., "t": ...},
+ "cycle": {"a": ..., "t": ...}
+ "auto": {"a": ..., "t": ...}
+ "tasks": {"h": ..., "d": ..., "s": ..., etc.}
+ },
+ "D": GANLoss,
+ "C": ...
+ }
+ """
+
+ losses = {
+ "G": {"a": {}, "p": {}, "tasks": {}},
+ "D": {"default": {}, "advent": {}},
+ "C": {},
+ }
+
+ # ------------------------------
+ # ----- Generator Losses -----
+ # ------------------------------
+
+ # painter losses
+ if "p" in opts.tasks:
+ losses["G"]["p"]["gan"] = (
+ HingeLoss()
+ if opts.gen.p.loss == "hinge"
+ else GANLoss(
+ use_lsgan=False,
+ soft_shift=opts.dis.soft_shift,
+ flip_prob=opts.dis.flip_prob,
+ )
+ )
+ losses["G"]["p"]["dm"] = MSELoss()
+ losses["G"]["p"]["vgg"] = VGGLoss(device)
+ losses["G"]["p"]["tv"] = TVLoss()
+ losses["G"]["p"]["context"] = ContextLoss()
+ losses["G"]["p"]["reconstruction"] = ReconstructionLoss()
+ losses["G"]["p"]["featmatch"] = FeatMatchLoss()
+
+ # depth losses
+ if "d" in opts.tasks:
+ if not opts.gen.d.classify.enable:
+ if opts.gen.d.loss == "dada":
+ depth_func = DADADepthLoss()
+ else:
+ depth_func = SIGMLoss(opts.train.lambdas.G.d.gml)
+ else:
+ depth_func = CrossEntropy()
+
+ losses["G"]["tasks"]["d"] = depth_func
+
+ # segmentation losses
+ if "s" in opts.tasks:
+ losses["G"]["tasks"]["s"] = {}
+ losses["G"]["tasks"]["s"]["crossent"] = CrossEntropy()
+ losses["G"]["tasks"]["s"]["minent"] = MinentLoss()
+ losses["G"]["tasks"]["s"]["advent"] = ADVENTAdversarialLoss(
+ opts, gan_type=opts.dis.s.gan_type
+ )
+
+ # masker losses
+ if "m" in opts.tasks:
+ losses["G"]["tasks"]["m"] = {}
+ losses["G"]["tasks"]["m"]["bce"] = nn.BCEWithLogitsLoss()
+ if opts.gen.m.use_minent_var:
+ losses["G"]["tasks"]["m"]["minent"] = MinentLoss(
+ version=2, lambda_var=opts.train.lambdas.advent.ent_var
+ )
+ else:
+ losses["G"]["tasks"]["m"]["minent"] = MinentLoss()
+ losses["G"]["tasks"]["m"]["tv"] = TVLoss()
+ losses["G"]["tasks"]["m"]["advent"] = ADVENTAdversarialLoss(
+ opts, gan_type=opts.dis.m.gan_type
+ )
+ losses["G"]["tasks"]["m"]["gi"] = GroundIntersectionLoss()
+
+ # ----------------------------------
+ # ----- Discriminator Losses -----
+ # ----------------------------------
+ if "p" in opts.tasks:
+ losses["D"]["p"] = losses["G"]["p"]["gan"]
+ if "m" in opts.tasks or "s" in opts.tasks:
+ losses["D"]["advent"] = ADVENTAdversarialLoss(opts)
+ return losses
+
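+# Minimal usage sketch, assuming `opts` is the project's addict.Dict configuration:
+#   losses = get_losses(opts, verbose=0, device=torch.device("cuda"))
+#   tv_loss = losses["G"]["tasks"]["m"]["tv"]    # available when "m" is in opts.tasks
+#   gan_loss = losses["G"]["p"]["gan"]           # available when "p" is in opts.tasks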
+
+class GroundIntersectionLoss(nn.Module):
+ """
+ Penalize areas that are in the (pseudo) ground segmentation but missing from the predicted flood mask
+ """
+
+ def __call__(self, pred, pseudo_ground):
+ return torch.mean(1.0 * ((pseudo_ground - pred) > 0.5))
+
+
+def prob_2_entropy(prob):
+ """
+ convert probabilistic prediction maps to weighted self-information maps
+ """
+ n, c, h, w = prob.size()
+ return -torch.mul(prob, torch.log2(prob + 1e-30)) / np.log2(c)
+
+
+class CustomBCELoss(nn.Module):
+ """
+ BCE loss between a raw (pre-sigmoid) prediction tensor and a scalar int target.
+ There is no need to apply a sigmoid before calling this loss: it uses
+ nn.BCEWithLogitsLoss internally.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.loss = nn.BCEWithLogitsLoss()
+
+ def __call__(self, prediction, target):
+ return self.loss(
+ prediction,
+ torch.FloatTensor(prediction.size())
+ .fill_(target)
+ .to(prediction.get_device()),
+ )
+
+
+class ADVENTAdversarialLoss(nn.Module):
+ """
+ Computes the ADVENT adversarial loss, used to indirectly shrink the domain
+ gap between sim and real.
+
+ __call__ arguments:
+ prediction: torch.Tensor with shape [bs, c, h, w]
+ target: int; domain label: 0 (sim) or 1 (real)
+ discriminator: the discriminator model that tells whether a tensor comes
+ from sim or real
+
+ output: the value of the GAN loss
+ """
+
+ def __init__(self, opts, gan_type="GAN"):
+ super().__init__()
+ self.opts = opts
+ if gan_type == "GAN":
+ self.loss = CustomBCELoss()
+ elif gan_type == "WGAN" or "WGAN_gp" or "WGAN_norm":
+ self.loss = lambda x, y: -torch.mean(y * x + (1 - y) * (1 - x))
+ else:
+ raise NotImplementedError
+
+ def __call__(self, prediction, target, discriminator, depth_preds=None):
+ """
+ Compute the GAN loss from the Advent Discriminator given
+ normalized (softmaxed) predictions (=pixel-wise class probabilities),
+ and int labels (target).
+
+ Args:
+ prediction (torch.Tensor): pixel-wise probability distribution over classes
+ target (torch.Tensor): pixel wise int target labels
+ discriminator (torch.nn.Module): Discriminator to get the loss
+
+ Returns:
+ torch.Tensor: float 0-D loss
+ """
+ d_out = prob_2_entropy(prediction)
+ if depth_preds is not None:
+ d_out = d_out * depth_preds
+ d_out = discriminator(d_out)
+ if self.opts.dis.m.architecture == "OmniDiscriminator":
+ d_out = multiDiscriminatorAdapter(d_out, self.opts)
+ loss_ = self.loss(d_out, target)
+ return loss_
+
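+# Minimal usage sketch, assuming `probs` are softmaxed masker/segmenter predictions
+# (N x C x H x W) and `discriminator` is the corresponding ADVENT discriminator:
+#   adv = ADVENTAdversarialLoss(opts, gan_type=opts.dis.m.gan_type)
+#   loss_real = adv(probs, 1, discriminator)    # 1 = real domain label
+#   loss_sim = adv(probs, 0, discriminator)     # 0 = sim domain label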
+
+def multiDiscriminatorAdapter(d_out: list, opts: dict) -> torch.Tensor:
+ """
+ The OmniDiscriminator does not directly return a tensor but a list of tensors.
+ Since there is no multilevel masker, the 0th tensor in the list is all we want:
+ this adapter returns the first element (tensor) of the list returned by the
+ OmniDiscriminator.
+ """
+ if (
+ isinstance(d_out, list) and len(d_out) == 1
+ ): # adapt the multi-scale OmniDiscriminator
+ if not opts.dis.p.get_intermediate_features:
+ d_out = d_out[0][0]
+ else:
+ d_out = d_out[0]
+ else:
+ raise Exception(
+ "Check the setting of OmniDiscriminator! "
+ + "For now, we don't support multi-scale OmniDiscriminator."
+ )
+ return d_out
+
+
+class HingeLoss(nn.Module):
+ """
+ Adapted from https://github.com/NVlabs/SPADE/blob/master/models/networks/loss.py
+ for the painter
+ """
+
+ def __init__(self, tensor=torch.FloatTensor):
+ super().__init__()
+ self.zero_tensor = None
+ self.Tensor = tensor
+
+ def get_zero_tensor(self, input):
+ if self.zero_tensor is None:
+ self.zero_tensor = self.Tensor(1).fill_(0)
+ self.zero_tensor.requires_grad_(False)
+ self.zero_tensor = self.zero_tensor.to(input.device)
+ return self.zero_tensor.expand_as(input)
+
+ def loss(self, input, target_is_real, for_discriminator=True):
+ if for_discriminator:
+ if target_is_real:
+ minval = torch.min(input - 1, self.get_zero_tensor(input))
+ loss = -torch.mean(minval)
+ else:
+ minval = torch.min(-input - 1, self.get_zero_tensor(input))
+ loss = -torch.mean(minval)
+ else:
+ assert target_is_real, "The generator's hinge loss must be aiming for real"
+ loss = -torch.mean(input)
+ return loss
+
+ def __call__(self, input, target_is_real, for_discriminator=True):
+ # computing the loss is a bit involved because `input` may not be
+ # a tensor but a list of tensors, in the case of a multiscale discriminator
+ if isinstance(input, list):
+ loss = 0
+ for pred_i in input:
+ if isinstance(pred_i, list):
+ pred_i = pred_i[-1]
+ loss_tensor = self.loss(pred_i, target_is_real, for_discriminator)
+ loss += loss_tensor
+ return loss / len(input)
+ else:
+ return self.loss(input, target_is_real, for_discriminator)
+
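+# Minimal usage sketch, assuming `d_real` / `d_fake` are raw discriminator outputs
+# (a tensor, or a list of tensors for a multi-scale discriminator):
+#   hinge = HingeLoss()
+#   d_loss = hinge(d_real, True) + hinge(d_fake, False)
+#   g_loss = hinge(d_fake, True, for_discriminator=False)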
+
+class DADADepthLoss:
+ """Defines the reverse Huber loss from DADA paper for depth prediction
+ - Samples with larger residuals are penalized more by l2 term
+ - Samples with smaller residuals are penalized more by l1 term
+ From https://github.com/valeoai/DADA/blob/master/dada/utils/func.py
+ """
+
+ def loss_calc_depth(self, pred, label):
+ n, c, h, w = pred.size()
+ assert c == 1
+
+ pred = pred.squeeze()
+ label = label.squeeze()
+
+ adiff = torch.abs(pred - label)
+ batch_max = 0.2 * torch.max(adiff).item()
+ t1_mask = adiff.le(batch_max).float()
+ t2_mask = adiff.gt(batch_max).float()
+ t1 = adiff * t1_mask
+ t2 = (adiff * adiff + batch_max * batch_max) / (2 * batch_max)
+ t2 = t2 * t2_mask
+ return (torch.sum(t1) + torch.sum(t2)) / torch.numel(pred.data)
+
+ def __call__(self, pred, label):
+ return self.loss_calc_depth(pred, label)
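+
+
+# Minimal usage sketch, assuming N x 1 x H x W depth predictions and labels;
+# residuals below 0.2 * max|pred - label| are penalized in L1, larger ones in L2:
+#   depth_loss = DADADepthLoss()(pred_depth, target_depth)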
diff --git a/climategan/masker.py b/climategan/masker.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bbd3063c29f55d80d38fd41db8f5d534f62c6d3
--- /dev/null
+++ b/climategan/masker.py
@@ -0,0 +1,234 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from climategan.blocks import (
+ BaseDecoder,
+ Conv2dBlock,
+ InterpolateNearest2d,
+ SPADEResnetBlock,
+)
+
+
+def create_mask_decoder(opts, no_init=False, verbose=0):
+ if opts.gen.m.use_spade:
+ if verbose > 0:
+ print(" - Add Spade Mask Decoder")
+ assert "d" in opts.tasks or "s" in opts.tasks
+ return MaskSpadeDecoder(opts)
+ else:
+ if verbose > 0:
+ print(" - Add Base Mask Decoder")
+ return MaskBaseDecoder(opts)
+
+
+class MaskBaseDecoder(BaseDecoder):
+ def __init__(self, opts):
+ low_level_feats_dim = -1
+ use_v3 = opts.gen.encoder.architecture == "deeplabv3"
+ use_mobile_net = opts.gen.deeplabv3.backbone == "mobilenet"
+ use_low = opts.gen.m.use_low_level_feats
+ use_dada = ("d" in opts.tasks) and opts.gen.m.use_dada
+
+ if use_v3 and use_mobile_net:
+ input_dim = 320
+ if use_low:
+ low_level_feats_dim = 24
+ elif use_v3:
+ input_dim = 2048
+ if use_low:
+ low_level_feats_dim = 256
+ else:
+ input_dim = 2048
+
+ super().__init__(
+ n_upsample=opts.gen.m.n_upsample,
+ n_res=opts.gen.m.n_res,
+ input_dim=input_dim,
+ proj_dim=opts.gen.m.proj_dim,
+ output_dim=opts.gen.m.output_dim,
+ norm=opts.gen.m.norm,
+ activ=opts.gen.m.activ,
+ pad_type=opts.gen.m.pad_type,
+ output_activ="none",
+ low_level_feats_dim=low_level_feats_dim,
+ use_dada=use_dada,
+ )
+
+
+class MaskSpadeDecoder(nn.Module):
+ def __init__(self, opts):
+ """Create a SPADE-based decoder, which forwards z and the conditioning
+ tensors seg (in the original paper, conditioning is on a semantic map only).
+ All along, z is conditioned on seg. First 3 SpadeResblocks (SRB) do not shrink
+ the channel dimension, and an upsampling is applied after each. Therefore
+ 2 upsamplings at this point. Then, for each remaining upsamplings
+ (w.r.t. spade_n_up), the SRB shrinks channels by 2. Before final conv to get 3
+ channels, the number of channels is therefore:
+ final_nc = channels(z) * 2 ** (spade_n_up - 2)
+ Args:
+ latent_dim (tuple): z's shape (only the number of channels matters)
+ cond_nc (int): conditioning tensor's expected number of channels
+ spade_n_up (int): Number of total upsamplings from z
+ spade_use_spectral_norm (bool): use spectral normalization?
+ spade_param_free_norm (str): norm to use before SPADE de-normalization
+ spade_kernel_size (int): SPADE conv layers' kernel size
+ Returns:
+ [type]: [description]
+ """
+ super().__init__()
+ self.opts = opts
+ latent_dim = opts.gen.m.spade.latent_dim
+ cond_nc = opts.gen.m.spade.cond_nc
+ spade_use_spectral_norm = opts.gen.m.spade.spade_use_spectral_norm
+ spade_param_free_norm = opts.gen.m.spade.spade_param_free_norm
+ if self.opts.gen.m.spade.activations.all_lrelu:
+ spade_activation = "lrelu"
+ else:
+ spade_activation = None
+ spade_kernel_size = 3
+ self.num_layers = opts.gen.m.spade.num_layers
+ self.z_nc = latent_dim
+
+ if (
+ opts.gen.encoder.architecture == "deeplabv3"
+ and opts.gen.deeplabv3.backbone == "mobilenet"
+ ):
+ self.input_dim = [320, 24]
+ self.low_level_conv = Conv2dBlock(
+ self.input_dim[1],
+ self.input_dim[0],
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+ self.merge_feats_conv = Conv2dBlock(
+ self.input_dim[0] * 2,
+ self.z_nc,
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+ elif (
+ opts.gen.encoder.architecture == "deeplabv3"
+ and opts.gen.deeplabv3.backbone == "resnet"
+ ):
+ self.input_dim = [2048, 256]
+ if self.opts.gen.m.use_proj:
+ proj_dim = self.opts.gen.m.proj_dim
+ self.low_level_conv = Conv2dBlock(
+ self.input_dim[1],
+ proj_dim,
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+ self.high_level_conv = Conv2dBlock(
+ self.input_dim[0],
+ proj_dim,
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+ self.merge_feats_conv = Conv2dBlock(
+ proj_dim * 2,
+ self.z_nc,
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+ else:
+ self.low_level_conv = Conv2dBlock(
+ self.input_dim[1],
+ self.input_dim[0],
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+ self.merge_feats_conv = Conv2dBlock(
+ self.input_dim[0] * 2,
+ self.z_nc,
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+
+ elif opts.gen.encoder.architecture == "deeplabv2":
+ self.input_dim = 2048
+ self.fc_conv = Conv2dBlock(
+ self.input_dim,
+ self.z_nc,
+ 3,
+ padding=1,
+ activation="lrelu",
+ pad_type="reflect",
+ norm="spectral_batch",
+ )
+ else:
+ raise ValueError("Unknown encoder type")
+
+ self.spade_blocks = []
+
+ for i in range(self.num_layers):
+ self.spade_blocks.append(
+ SPADEResnetBlock(
+ int(self.z_nc / (2**i)),
+ int(self.z_nc / (2 ** (i + 1))),
+ cond_nc,
+ spade_use_spectral_norm,
+ spade_param_free_norm,
+ spade_kernel_size,
+ spade_activation,
+ )
+ )
+ self.spade_blocks = nn.Sequential(*self.spade_blocks)
+
+ self.final_nc = int(self.z_nc / (2**self.num_layers))
+ self.mask_conv = Conv2dBlock(
+ self.final_nc,
+ 1,
+ 3,
+ padding=1,
+ activation="none",
+ pad_type="reflect",
+ norm="spectral",
+ )
+ self.upsample = InterpolateNearest2d(scale_factor=2)
+
+ def forward(self, z, cond, z_depth=None):
+ if isinstance(z, (list, tuple)):
+ z_h, z_l = z
+ if self.opts.gen.m.use_proj:
+ z_l = self.low_level_conv(z_l)
+ z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
+ z_h = self.high_level_conv(z_h)
+ else:
+ z_l = self.low_level_conv(z_l)
+ z_l = F.interpolate(z_l, size=z_h.shape[-2:], mode="bilinear")
+ z = torch.cat([z_h, z_l], axis=1)
+ y = self.merge_feats_conv(z)
+ else:
+ y = self.fc_conv(z)
+
+ for i in range(self.num_layers):
+ y = self.spade_blocks[i](y, cond)
+ y = self.upsample(y)
+ y = self.mask_conv(y)
+ return y
+
+ def __str__(self):
+ return "MaskerSpadeDecoder"
diff --git a/climategan/norms.py b/climategan/norms.py
new file mode 100644
index 0000000000000000000000000000000000000000..c448248488af0baf131628e994cb17df20a58cbd
--- /dev/null
+++ b/climategan/norms.py
@@ -0,0 +1,186 @@
+"""Normalization layers used in blocks
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AdaptiveInstanceNorm2d(nn.Module):
+ def __init__(self, num_features, eps=1e-5, momentum=0.1):
+ super(AdaptiveInstanceNorm2d, self).__init__()
+ self.num_features = num_features
+ self.eps = eps
+ self.momentum = momentum
+ # weight and bias are dynamically assigned
+ self.weight = None
+ self.bias = None
+ # just dummy buffers, not used
+ self.register_buffer("running_mean", torch.zeros(num_features))
+ self.register_buffer("running_var", torch.ones(num_features))
+
+ def forward(self, x):
+ assert (
+ self.weight is not None and self.bias is not None
+ ), "Please assign weight and bias before calling AdaIN!"
+ b, c = x.size(0), x.size(1)
+ running_mean = self.running_mean.repeat(b)
+ running_var = self.running_var.repeat(b)
+
+ # Apply instance norm
+ x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:])
+
+ out = F.batch_norm(
+ x_reshaped,
+ running_mean,
+ running_var,
+ self.weight,
+ self.bias,
+ True,
+ self.momentum,
+ self.eps,
+ )
+
+ return out.view(b, c, *x.size()[2:])
+
+ def __repr__(self):
+ return self.__class__.__name__ + "(" + str(self.num_features) + ")"
+
+
+class LayerNorm(nn.Module):
+ def __init__(self, num_features, eps=1e-5, affine=True):
+ super(LayerNorm, self).__init__()
+ self.num_features = num_features
+ self.affine = affine
+ self.eps = eps
+
+ if self.affine:
+ self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_())
+ self.beta = nn.Parameter(torch.zeros(num_features))
+
+ def forward(self, x):
+ shape = [-1] + [1] * (x.dim() - 1)
+ # print(x.size())
+ if x.size(0) == 1:
+ # These two lines run much faster in pytorch 0.4
+ # than the two lines listed below.
+ mean = x.view(-1).mean().view(*shape)
+ std = x.view(-1).std().view(*shape)
+ else:
+ mean = x.view(x.size(0), -1).mean(1).view(*shape)
+ std = x.view(x.size(0), -1).std(1).view(*shape)
+
+ x = (x - mean) / (std + self.eps)
+
+ if self.affine:
+ shape = [1, -1] + [1] * (x.dim() - 2)
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
+ return x
+
+
+def l2normalize(v, eps=1e-12):
+ return v / (v.norm() + eps)
+
+
+class SpectralNorm(nn.Module):
+ """
+ Based on the paper "Spectral Normalization for Generative Adversarial Networks"
+ by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida and the
+ Pytorch implementation:
+ https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
+ """
+
+ def __init__(self, module, name="weight", power_iterations=1):
+ super().__init__()
+ self.module = module
+ self.name = name
+ self.power_iterations = power_iterations
+ if not self._made_params():
+ self._make_params()
+
+ def _update_u_v(self):
+ u = getattr(self.module, self.name + "_u")
+ v = getattr(self.module, self.name + "_v")
+ w = getattr(self.module, self.name + "_bar")
+
+ height = w.data.shape[0]
+ for _ in range(self.power_iterations):
+ v.data = l2normalize(torch.mv(torch.t(w.view(height, -1).data), u.data))
+ u.data = l2normalize(torch.mv(w.view(height, -1).data, v.data))
+
+ # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data))
+ sigma = u.dot(w.view(height, -1).mv(v))
+ setattr(self.module, self.name, w / sigma.expand_as(w))
+
+ def _made_params(self):
+ try:
+ u = getattr(self.module, self.name + "_u") # noqa: F841
+ v = getattr(self.module, self.name + "_v") # noqa: F841
+ w = getattr(self.module, self.name + "_bar") # noqa: F841
+ return True
+ except AttributeError:
+ return False
+
+ def _make_params(self):
+ w = getattr(self.module, self.name)
+
+ height = w.data.shape[0]
+ width = w.view(height, -1).data.shape[1]
+
+ u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
+ v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False)
+ u.data = l2normalize(u.data)
+ v.data = l2normalize(v.data)
+ w_bar = nn.Parameter(w.data)
+
+ del self.module._parameters[self.name]
+
+ self.module.register_parameter(self.name + "_u", u)
+ self.module.register_parameter(self.name + "_v", v)
+ self.module.register_parameter(self.name + "_bar", w_bar)
+
+ def forward(self, *args):
+ self._update_u_v()
+ return self.module.forward(*args)
+
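+# Minimal usage sketch, assuming a standard torch.nn layer to wrap; the weight is
+# re-normalized by its estimated largest singular value at every forward call:
+#   conv = SpectralNorm(nn.Conv2d(64, 128, 3, padding=1))
+#   out = conv(x)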
+
+class SPADE(nn.Module):
+ def __init__(self, param_free_norm_type, kernel_size, norm_nc, cond_nc):
+ super().__init__()
+
+ if param_free_norm_type == "instance":
+ self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+ # elif param_free_norm_type == "syncbatch":
+ # self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
+ elif param_free_norm_type == "batch":
+ self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
+ else:
+ raise ValueError(
+ "%s is not a recognized param-free norm type in SPADE"
+ % param_free_norm_type
+ )
+
+ # The dimension of the intermediate embedding space. Yes, hardcoded.
+ nhidden = 128
+
+ pw = kernel_size // 2
+ self.mlp_shared = nn.Sequential(
+ nn.Conv2d(cond_nc, nhidden, kernel_size=kernel_size, padding=pw), nn.ReLU()
+ )
+ self.mlp_gamma = nn.Conv2d(
+ nhidden, norm_nc, kernel_size=kernel_size, padding=pw
+ )
+ self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=kernel_size, padding=pw)
+
+ def forward(self, x, segmap):
+ # Part 1. generate parameter-free normalized activations
+ normalized = self.param_free_norm(x)
+
+ # Part 2. produce scaling and bias conditioned on semantic map
+ segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
+ actv = self.mlp_shared(segmap)
+ gamma = self.mlp_gamma(actv)
+ beta = self.mlp_beta(actv)
+ # apply scale and bias
+ out = normalized * (1 + gamma) + beta
+
+ return out
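+
+
+# Minimal usage sketch, assuming `x` is an N x 256 x H x W feature map and `segmap`
+# a 3-channel conditioning map (it is resized to x's spatial size internally):
+#   spade = SPADE("instance", kernel_size=3, norm_nc=256, cond_nc=3)
+#   y = spade(x, segmap)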
diff --git a/climategan/optim.py b/climategan/optim.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e6ffea333aedcb4b06ed5fcf7306affc453bee1
--- /dev/null
+++ b/climategan/optim.py
@@ -0,0 +1,291 @@
+"""Define ExtraAdam and schedulers
+"""
+import math
+
+import torch
+from torch.optim import Adam, Optimizer, RMSprop, lr_scheduler
+from torch_optimizer import NovoGrad, RAdam
+
+
+def get_scheduler(optimizer, hyperparameters, iterations=-1):
+ """Get an optimizer's learning rate scheduler based on opts
+
+ Args:
+ optimizer (torch.Optimizer): optimizer for which to schedule the learning rate
+ hyperparameters (addict.Dict): configuration options
+ iterations (int, optional): The index of last epoch. Defaults to -1.
+ When last_epoch=-1, sets initial lr as lr.
+
+ Returns:
+ torch.optim.lr_scheduler._LRScheduler | None: the scheduler, or None for a constant policy
+ """
+
+ policy = hyperparameters.get("lr_policy")
+ lr_step_size = hyperparameters.get("lr_step_size")
+ lr_gamma = hyperparameters.get("lr_gamma")
+ milestones = hyperparameters.get("lr_milestones")
+
+ if policy is None or policy == "constant":
+ scheduler = None # constant scheduler
+ elif policy == "step":
+ scheduler = lr_scheduler.StepLR(
+ optimizer, step_size=lr_step_size, gamma=lr_gamma, last_epoch=iterations,
+ )
+ elif policy == "multi_step":
+ if isinstance(milestones, (list, tuple)):
+ milestones = milestones
+ elif isinstance(milestones, int):
+ assert "lr_step_size" in hyperparameters
+ if iterations == -1:
+ last_milestone = 1000
+ else:
+ last_milestone = iterations
+ milestones = list(range(milestones, last_milestone, lr_step_size))
+ scheduler = lr_scheduler.MultiStepLR(
+ optimizer, milestones=milestones, gamma=lr_gamma, last_epoch=iterations,
+ )
+ else:
+ raise NotImplementedError(
+ "learning rate policy [%s] is not implemented" % policy
+ )
+ return scheduler
+
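+# Minimal config sketch, assuming a dict-like config (the keys below mirror the
+# `get()` calls above):
+#   conf = {"lr_policy": "step", "lr_step_size": 30, "lr_gamma": 0.5}
+#   scheduler = get_scheduler(optimizer, conf)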
+
+def get_optimizer(net, opt_conf, tasks=None, is_disc=False, iterations=-1):
+ """Returns a tuple (optimizer, scheduler) according to opt_conf which
+ should come from the trainer's opts as trainer.opts.gen.opt or trainer.opts.dis.opt
+
+ Args:
+ net (nn.Module): Network to update
+ opt_conf (addict.Dict): optimizer and scheduler options
+ tasks: list of tasks
+ iterations (int, optional): Last epoch number. Defaults to -1, meaning
+ start with base lr.
+
+ Returns:
+ Tuple: (torch.Optimizer, torch._LRScheduler)
+ """
+ opt = scheduler = None
+ lr_names = []
+ if tasks is None:
+ lr_default = opt_conf.lr
+ params = net.parameters()
+ lr_names.append("full")
+ elif isinstance(opt_conf.lr, float): # Use default for all tasks
+ lr_default = opt_conf.lr
+ params = net.parameters()
+ lr_names.append("full")
+ elif len(opt_conf.lr) == 1: # Use default for all tasks
+ lr_default = opt_conf.lr.default
+ params = net.parameters()
+ lr_names.append("full")
+ else:
+ lr_default = opt_conf.lr.default
+ params = list()
+ for task in tasks:
+ lr = opt_conf.lr.get(task, lr_default)
+ parameters = None
+ # Parameters for encoder
+ if not is_disc:
+ if task == "m":
+ parameters = net.encoder.parameters()
+ params.append({"params": parameters, "lr": lr})
+ lr_names.append("encoder")
+ # Parameters for decoders
+ if task == "p":
+ if hasattr(net, "painter"):
+ parameters = net.painter.parameters()
+ lr_names.append("painter")
+ else:
+ parameters = net.decoders[task].parameters()
+ lr_names.append(f"decoder_{task}")
+ else:
+ if task in net:
+ parameters = net[task].parameters()
+ lr_names.append(f"disc_{task}")
+
+ if parameters is not None:
+ params.append({"params": parameters, "lr": lr})
+
+ if opt_conf.optimizer.lower() == "extraadam":
+ opt = ExtraAdam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
+ elif opt_conf.optimizer.lower() == "novograd":
+ opt = NovoGrad(
+ params, lr=lr_default, betas=(opt_conf.beta1, 0)
+ ) # default for beta2 is 0
+ elif opt_conf.optimizer.lower() == "radam":
+ opt = RAdam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
+ elif opt_conf.optimizer.lower() == "rmsprop":
+ opt = RMSprop(params, lr=lr_default)
+ else:
+ opt = Adam(params, lr=lr_default, betas=(opt_conf.beta1, 0.999))
+ scheduler = get_scheduler(opt, opt_conf, iterations)
+ return opt, scheduler, lr_names
+
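+# Minimal usage sketch, mirroring how the Trainer calls this function, assuming
+# `G` is the generator and `opts` the project's addict.Dict configuration:
+#   g_opt, g_scheduler, lr_names = get_optimizer(G, opts.gen.opt, opts.tasks)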
+
+"""
+Extragradient Optimizer
+
+Mostly copied from the extragrad paper repo.
+
+MIT License
+Copyright (c) Facebook, Inc. and its affiliates.
+written by Hugo Berard (berard.hugo@gmail.com) while at Facebook.
+"""
+
+
+class Extragradient(Optimizer):
+ """Base class for optimizers with extrapolation step.
+ Arguments:
+ params (iterable): an iterable of :class:`torch.Tensor` s or
+ :class:`dict` s. Specifies what Tensors should be optimized.
+ defaults: (dict): a dict containing default values of optimization
+ options (used when a parameter group doesn't specify them).
+ """
+
+ def __init__(self, params, defaults):
+ super(Extragradient, self).__init__(params, defaults)
+ self.params_copy = []
+
+ def update(self, p, group):
+ raise NotImplementedError
+
+ def extrapolation(self):
+ """Performs the extrapolation step and save a copy of the current
+ parameters for the update step.
+ """
+ # Check if a copy of the parameters was already made.
+ is_empty = len(self.params_copy) == 0
+ for group in self.param_groups:
+ for p in group["params"]:
+ u = self.update(p, group)
+ if is_empty:
+ # Save the current parameters for the update step.
+ # Several extrapolation steps can be made before each update but
+ # only the parameters before the first extrapolation step are saved.
+ self.params_copy.append(p.data.clone())
+ if u is None:
+ continue
+ # Update the current parameters
+ p.data.add_(u)
+
+ def step(self, closure=None):
+ """Performs a single optimization step.
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ if len(self.params_copy) == 0:
+ raise RuntimeError("Need to call extrapolation before calling step.")
+
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ i = -1
+ for group in self.param_groups:
+ for p in group["params"]:
+ i += 1
+ u = self.update(p, group)
+ if u is None:
+ continue
+ # Update the parameters saved during the extrapolation step
+ p.data = self.params_copy[i].add_(u)
+
+ # Free the old parameters
+ self.params_copy = []
+ return loss
+
+
+class ExtraAdam(Extragradient):
+ """Implements the Adam algorithm with extrapolation step.
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+ amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+ algorithm from the paper `On the Convergence of Adam and Beyond`_
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.999),
+ eps=1e-8,
+ weight_decay=0,
+ amsgrad=False,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ defaults = dict(
+ lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad
+ )
+ super(ExtraAdam, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(ExtraAdam, self).__setstate__(state)
+ for group in self.param_groups:
+ group.setdefault("amsgrad", False)
+
+ def update(self, p, group):
+ if p.grad is None:
+ return None
+ grad = p.grad.data
+ if grad.is_sparse:
+ raise RuntimeError(
+ "Adam does not support sparse gradients,"
+ + " please consider SparseAdam instead"
+ )
+ amsgrad = group["amsgrad"]
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(p.data)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state["max_exp_avg_sq"] = torch.zeros_like(p.data)
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+ if amsgrad:
+ max_exp_avg_sq = state["max_exp_avg_sq"]
+ beta1, beta2 = group["betas"]
+
+ state["step"] += 1
+
+ if group["weight_decay"] != 0:
+ grad = grad.add(group["weight_decay"], p.data)
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ if amsgrad:
+ # Maintains the maximum of all 2nd moment running avg. till now
+ torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) # type: ignore
+ # Use the max. for normalizing running avg. of gradient
+ denom = max_exp_avg_sq.sqrt().add_(group["eps"]) # type: ignore
+ else:
+ denom = exp_avg_sq.sqrt().add_(group["eps"])
+
+ bias_correction1 = 1 - beta1 ** state["step"]
+ bias_correction2 = 1 - beta2 ** state["step"]
+ step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1
+
+ return -step_size * exp_avg / denom
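+
+
+# Minimal usage sketch of the extrapolation/update pattern, assuming a standard
+# training loop (`model`, `loader` and `compute_loss` are placeholders); the
+# Trainer alternates the two calls on even/odd global steps in the same way:
+#   opt = ExtraAdam(model.parameters(), lr=1e-4)
+#   for step, batch in enumerate(loader):
+#       loss = compute_loss(model, batch)
+#       opt.zero_grad()
+#       loss.backward()
+#       if step % 2 == 0:
+#           opt.extrapolation()
+#       else:
+#           opt.step()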
diff --git a/climategan/painter.py b/climategan/painter.py
new file mode 100644
index 0000000000000000000000000000000000000000..739ec2b1bda94a7b37ea17b5d757e009255bd312
--- /dev/null
+++ b/climategan/painter.py
@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import climategan.strings as strings
+from climategan.blocks import InterpolateNearest2d, SPADEResnetBlock
+from climategan.norms import SpectralNorm
+
+
+def create_painter(opts, no_init=False, verbose=0):
+ if verbose > 0:
+ print(" - Add PainterSpadeDecoder Painter")
+ return PainterSpadeDecoder(opts)
+
+
+class PainterSpadeDecoder(nn.Module):
+ def __init__(self, opts):
+ """Create a SPADE-based decoder, which forwards z and the conditioning
+ tensors seg (in the original paper, conditioning is on a semantic map only).
+ All along, z is conditioned on seg. First 3 SpadeResblocks (SRB) do not shrink
+ the channel dimension, and an upsampling is applied after each. Therefore
+ 2 upsamplings at this point. Then, for each remaining upsampling
+ (w.r.t. spade_n_up), the SRB shrinks channels by 2. Before the final conv to get 3
+ channels, the number of channels is therefore:
+ final_nc = channels(z) / 2 ** (spade_n_up - 2)
+ Args:
+ latent_dim (tuple): z's shape (only the number of channels matters)
+ cond_nc (int): conditioning tensor's expected number of channels
+ spade_n_up (int): Number of total upsamplings from z
+ spade_use_spectral_norm (bool): use spectral normalization?
+ spade_param_free_norm (str): norm to use before SPADE de-normalization
+ spade_kernel_size (int): SPADE conv layers' kernel size
+ """
+ super().__init__()
+
+ latent_dim = opts.gen.p.latent_dim
+ cond_nc = 3
+ spade_n_up = opts.gen.p.spade_n_up
+ spade_use_spectral_norm = opts.gen.p.spade_use_spectral_norm
+ spade_param_free_norm = opts.gen.p.spade_param_free_norm
+ spade_kernel_size = 3
+
+ self.z_nc = latent_dim
+ self.spade_n_up = spade_n_up
+
+ self.z_h = self.z_w = None
+
+ self.fc = nn.Conv2d(3, latent_dim, 3, padding=1)
+ self.head_0 = SPADEResnetBlock(
+ self.z_nc,
+ self.z_nc,
+ cond_nc,
+ spade_use_spectral_norm,
+ spade_param_free_norm,
+ spade_kernel_size,
+ )
+
+ self.G_middle_0 = SPADEResnetBlock(
+ self.z_nc,
+ self.z_nc,
+ cond_nc,
+ spade_use_spectral_norm,
+ spade_param_free_norm,
+ spade_kernel_size,
+ )
+ self.G_middle_1 = SPADEResnetBlock(
+ self.z_nc,
+ self.z_nc,
+ cond_nc,
+ spade_use_spectral_norm,
+ spade_param_free_norm,
+ spade_kernel_size,
+ )
+
+ self.up_spades = nn.Sequential(
+ *[
+ SPADEResnetBlock(
+ self.z_nc // 2 ** i,
+ self.z_nc // 2 ** (i + 1),
+ cond_nc,
+ spade_use_spectral_norm,
+ spade_param_free_norm,
+ spade_kernel_size,
+ )
+ for i in range(spade_n_up - 2)
+ ]
+ )
+
+ self.final_nc = self.z_nc // 2 ** (spade_n_up - 2)
+
+ self.final_spade = SPADEResnetBlock(
+ self.final_nc,
+ self.final_nc,
+ cond_nc,
+ spade_use_spectral_norm,
+ spade_param_free_norm,
+ spade_kernel_size,
+ )
+ self.final_shortcut = None
+ if opts.gen.p.use_final_shortcut:
+ self.final_shortcut = nn.Sequential(
+ *[
+ SpectralNorm(nn.Conv2d(self.final_nc, 3, 1)),
+ nn.BatchNorm2d(3),
+ nn.LeakyReLU(0.2, True),
+ ]
+ )
+
+ self.conv_img = nn.Conv2d(self.final_nc, 3, 3, padding=1)
+
+ self.upsample = InterpolateNearest2d(scale_factor=2)
+
+ def set_latent_shape(self, shape, is_input=True):
+ """
+ Sets the latent shape to start the upsampling from, i.e. z_h and z_w.
+ If is_input is True, then this is the actual input shape which should
+ be divided by 2 ** spade_n_up.
+ Otherwise, just sets z_h and z_w from shape[-2] and shape[-1]
+
+ Args:
+ shape (tuple): The shape to start sampling from.
+ is_input (bool, optional): Whether to divide shape by 2 ** spade_n_up
+ """
+ if isinstance(shape, (list, tuple)):
+ self.z_h = shape[-2]
+ self.z_w = shape[-1]
+ elif isinstance(shape, int):
+ self.z_h = self.z_w = shape
+ else:
+ raise ValueError("Unknown shape type:", shape)
+
+ if is_input:
+ self.z_h = self.z_h // (2 ** self.spade_n_up)
+ self.z_w = self.z_w // (2 ** self.spade_n_up)
+
+ def _apply(self, fn):
+ # print("Applying SpadeDecoder", fn)
+ super()._apply(fn)
+ # self.head_0 = fn(self.head_0)
+ # self.G_middle_0 = fn(self.G_middle_0)
+ # self.G_middle_1 = fn(self.G_middle_1)
+ # for i, up in enumerate(self.up_spades):
+ # self.up_spades[i] = fn(up)
+ # self.conv_img = fn(self.conv_img)
+ return self
+
+ def forward(self, z, cond):
+ if z is None:
+ assert self.z_h is not None and self.z_w is not None
+ z = self.fc(F.interpolate(cond, size=(self.z_h, self.z_w)))
+ y = self.head_0(z, cond)
+ y = self.upsample(y)
+ y = self.G_middle_0(y, cond)
+ y = self.upsample(y)
+ y = self.G_middle_1(y, cond)
+
+ for i, up in enumerate(self.up_spades):
+ y = self.upsample(y)
+ y = up(y, cond)
+
+ if self.final_shortcut is not None:
+ cond = self.final_shortcut(y)
+ y = self.final_spade(y, cond)
+ y = self.conv_img(F.leaky_relu(y, 2e-1))
+ y = torch.tanh(y)
+ return y
+
+ def __str__(self):
+ return strings.spadedecoder(self)
diff --git a/climategan/strings.py b/climategan/strings.py
new file mode 100644
index 0000000000000000000000000000000000000000..37d1af144a7ace94bc07a1e005a7b7d4406f31b6
--- /dev/null
+++ b/climategan/strings.py
@@ -0,0 +1,99 @@
+"""custom __str__ methods for ClimateGAN's classes
+"""
+import torch
+import torch.nn as nn
+
+
+def title(name, color="\033[94m"):
+ name = "==== " + name + " ===="
+ s = "=" * len(name)
+ s = f"{s}\n{name}\n{s}"
+ return f"\033[1m{color}{s}\033[0m"
+
+
+def generator(G):
+ s = title("OmniGenerator", "\033[95m") + "\n"
+
+ s += str(G.encoder) + "\n\n"
+ for d in G.decoders:
+ if d not in {"a", "t"}:
+ s += str(G.decoders[d]) + "\n\n"
+ elif d == "a":
+ s += "[r & s]\n" + str(G.decoders["a"]["r"]) + "\n\n"
+ else:
+ if G.opts.gen.t.use_bit_conditioning:
+ s += "[bit]\n" + str(G.decoders["t"]) + "\n\n"
+ else:
+ s += "[f & n]\n" + str(G.decoders["t"]["f"]) + "\n\n"
+ return s.strip()
+
+
+def encoder(E):
+ s = title("Encoder") + "\n"
+ for b in E.model:
+ s += str(b) + "\n"
+ return s.strip()
+
+
+def get_conv_weight(conv):
+ weight = torch.Tensor(
+ conv.out_channels, conv.in_channels // conv.groups, *conv.kernel_size
+ )
+ return weight.shape
+
+
+def conv2dblock(obj):
+ name = "{:20}".format("Conv2dBlock")
+ s = ""
+ if "SpectralNorm" in obj.conv.__class__.__name__:
+ s = "SpectralNorm => "
+ w = str(tuple(get_conv_weight(obj.conv.module)))
+ else:
+ w = str(tuple(get_conv_weight(obj.conv)))
+ return f"{name}{s}{w}".strip()
+
+
+def resblocks(rb):
+ s = "{}\n".format(f"ResBlocks({len(rb.model)})")
+ for i, r in enumerate(rb.model):
+ s += f" - ({i}) {str(r)}\n"
+ return s.strip()
+
+
+def resblock(rb):
+ s = "{:12}".format("Resblock")
+ return f"{s}{rb.dim} channels, {rb.norm} norm + {rb.activation}"
+
+
+def basedecoder(bd):
+ s = title(bd.__class__.__name__) + "\n"
+ for b in bd.model:
+ if isinstance(b, nn.Upsample) or "InterpolateNearest2d" in b.__class__.__name__:
+ s += "{:20}".format("Upsample") + "x2\n"
+ else:
+ s += str(b) + "\n"
+ return s.strip()
+
+
+def spaderesblock(srb):
+ name = "{:20}".format("SPADEResnetBlock") + f"k {srb.kernel_size}, "
+ s = f"{name}{srb.fin} > {srb.fout}, "
+ s += f"param_free_norm: {srb.param_free_norm}, "
+ s += f"spectral_norm: {srb.use_spectral_norm}"
+ return s.strip()
+
+
+def spadedecoder(sd):
+ s = title(sd.__class__.__name__) + "\n"
+ up = "{:20}x2\n".format("Upsample")
+ s += up
+ s += str(sd.head_0) + "\n"
+ s += up
+ s += str(sd.G_middle_0) + "\n"
+ s += up
+ s += str(sd.G_middle_1) + "\n"
+ for i, u in enumerate(sd.up_spades):
+ s += up
+ s += str(u) + "\n"
+ s += "{:20}".format("Conv2d") + str(tuple(get_conv_weight(sd.conv_img))) + " tanh"
+ return s
diff --git a/climategan/trainer.py b/climategan/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..95688cf98b7e293d4f3d3f1155ec0083d773d082
--- /dev/null
+++ b/climategan/trainer.py
@@ -0,0 +1,1939 @@
+"""
+Main component: the trainer handles everything:
+ * initializations
+ * training
+ * saving
+"""
+import inspect
+import warnings
+from copy import deepcopy
+from pathlib import Path
+from time import time
+
+import numpy as np
+from comet_ml import ExistingExperiment, Experiment
+
+warnings.simplefilter("ignore", UserWarning)
+
+import torch
+import torch.nn as nn
+from addict import Dict
+from torch import autograd, sigmoid, softmax
+from torch.cuda.amp import GradScaler, autocast
+from tqdm import tqdm
+
+from climategan.data import get_all_loaders
+from climategan.discriminator import OmniDiscriminator, create_discriminator
+from climategan.eval_metrics import accuracy, mIOU
+from climategan.fid import compute_val_fid
+from climategan.fire import add_fire
+from climategan.generator import OmniGenerator, create_generator
+from climategan.logger import Logger
+from climategan.losses import get_losses
+from climategan.optim import get_optimizer
+from climategan.transforms import DiffTransforms
+from climategan.tutils import (
+ divide_pred,
+ get_num_params,
+ get_WGAN_gradient,
+ lrgb2srgb,
+ normalize,
+ print_num_parameters,
+ shuffle_batch_tuple,
+ srgb2lrgb,
+ vgg_preprocess,
+ zero_grad,
+)
+from climategan.utils import (
+ comet_kwargs,
+ div_dict,
+ find_target_size,
+ flatten_opts,
+ get_display_indices,
+ get_existing_comet_id,
+ get_latest_opts,
+ merge,
+ resolve,
+ sum_dict,
+ Timer,
+)
+
+try:
+ import torch_xla.core.xla_model as xm # type: ignore
+except ImportError:
+ pass
+
+
+class Trainer:
+ """Main trainer class"""
+
+ def __init__(self, opts, comet_exp=None, verbose=0, device=None):
+ """Trainer class to gather various model training procedures
+ such as training evaluating saving and logging
+
+ init:
+ * creates an addict.Dict logger
+ * creates logger.exp as a comet_exp experiment if `comet` arg is True
+ * sets the device (1 GPU or CPU)
+
+ Args:
+ opts (addict.Dict): options to configure the trainer, the data, the models
+ comet (bool, optional): whether to log the trainer with comet.ml.
+ Defaults to False.
+ verbose (int, optional): printing level to debug. Defaults to 0.
+ """
+ super().__init__()
+
+ self.opts = opts
+ self.verbose = verbose
+ self.logger = Logger(self)
+
+ self.losses = None
+ self.G = self.D = None
+ self.real_val_fid_stats = None
+ self.use_pl4m = False
+ self.is_setup = False
+ self.loaders = self.all_loaders = None
+ self.exp = None
+
+ self.current_mode = "train"
+ self.diff_transforms = None
+ self.kitti_pretrain = self.opts.train.kitti.pretrain
+ self.pseudo_training_tasks = set(self.opts.train.pseudo.tasks)
+
+ self.lr_names = {}
+ self.base_display_images = {}
+ self.kitty_display_images = {}
+ self.domain_labels = {"s": 0, "r": 1}
+
+ self.device = device or torch.device(
+ "cuda:0" if torch.cuda.is_available() else "cpu"
+ )
+
+ if isinstance(comet_exp, Experiment):
+ self.exp = comet_exp
+
+ if self.opts.train.amp:
+ optimizers = [
+ self.opts.gen.opt.optimizer.lower(),
+ self.opts.dis.opt.optimizer.lower(),
+ ]
+ if "extraadam" in optimizers:
+ raise ValueError(
+ "AMP does not work with ExtraAdam ({})".format(optimizers)
+ )
+ self.grad_scaler_d = GradScaler()
+ self.grad_scaler_g = GradScaler()
+
+ # -------------------------------
+ # ----- Legacy Overwrites -----
+ # -------------------------------
+ if (
+ self.opts.gen.s.depth_feat_fusion is True
+ or self.opts.gen.s.depth_dada_fusion is True
+ ):
+ self.opts.gen.s.use_dada = True
+
+ @torch.no_grad()
+ def paint_and_mask(self, image_batch, mask_batch=None, resolution="approx"):
+ """
+ Paints a batch of images (or a single image with a batch dim of 1). If
+ masks are not provided, they are inferred from the masker.
+ Resolution can either be the train-time resolution or the closest
+ multiple of 2 ** spade_n_up
+
+ Operations performed without gradient
+
+ If resolution == "approx" then the output image has the shape:
+ (dim // 2 ** spade_n_up) * 2 ** spade_n_up, for dim in [height, width]
+ eg: (1000, 1300) => (896, 1280) for spade_n_up = 7
+ If resolution == "exact" then the output image has the same shape:
+ we first process in "approx" mode then upsample bilinear
+ If resolution == "basic" image output shape is the train-time's
+ (typically 640x640)
+ If resolution == "upsample" image is inferred as "basic" and
+ then upsampled to original size
+
+ Args:
+ image_batch (torch.Tensor): 4D batch of images to flood
+ mask_batch (torch.Tensor, optional): Masks for the images.
+ Defaults to None (infer with Masker).
+ resolution (str, optional): "approx", "exact" or False
+
+ Returns:
+ torch.Tensor: N x C x H x W where H and W depend on `resolution`
+ """
+ assert resolution in {"approx", "exact", "basic", "upsample"}
+ previous_mode = self.current_mode
+ if previous_mode == "train":
+ self.eval_mode()
+
+ if mask_batch is None:
+ mask_batch = self.G.mask(x=image_batch)
+ else:
+ assert len(image_batch) == len(mask_batch)
+ assert image_batch.shape[-2:] == mask_batch.shape[-2:]
+
+ if resolution not in {"approx", "exact"}:
+ painted = self.G.paint(mask_batch, image_batch)
+
+ if resolution == "upsample":
+ painted = nn.functional.interpolate(
+ painted, size=image_batch.shape[-2:], mode="bilinear"
+ )
+ else:
+ # save latent shape
+ zh = self.G.painter.z_h
+ zw = self.G.painter.z_w
+ # adapt latent shape to approximately keep the resolution
+ self.G.painter.z_h = (
+ image_batch.shape[-2] // 2**self.opts.gen.p.spade_n_up
+ )
+ self.G.painter.z_w = (
+ image_batch.shape[-1] // 2**self.opts.gen.p.spade_n_up
+ )
+
+ painted = self.G.paint(mask_batch, image_batch)
+
+ self.G.painter.z_h = zh
+ self.G.painter.z_w = zw
+ if resolution == "exact":
+ painted = nn.functional.interpolate(
+ painted, size=image_batch.shape[-2:], mode="bilinear"
+ )
+
+ if previous_mode == "train":
+ self.train_mode()
+
+ return painted
+
+ def _p(self, *args, **kwargs):
+ """
+ verbose-dependent print util
+ """
+ if self.verbose > 0:
+ print(*args, **kwargs)
+
+ @torch.no_grad()
+ def infer_all(
+ self,
+ x,
+ numpy=True,
+ stores={},
+ bin_value=-1,
+ half=False,
+ xla=False,
+ cloudy=False,
+ auto_resize_640=False,
+ ignore_event=set(),
+ return_masks=False,
+ ):
+ """
+ Creates a dictionary of events ("flood", "wildfire", "smog") from numpy or
+ tensor image data, single image or batch.
+
+ stores is a dictionary of times for the Timer class.
+
+ bin_value is used to binarize (or not) flood masks
+ """
+ assert self.is_setup
+ assert len(x.shape) in {3, 4}, f"Unknown Data shape {x.shape}"
+
+ # convert numpy to tensor
+ if not isinstance(x, torch.Tensor):
+ x = torch.tensor(x, device=self.device)
+
+ # add batch dimension
+ if len(x.shape) == 3:
+ x.unsqueeze_(0)
+
+ # permute channels as second dimension
+ if x.shape[1] != 3:
+ assert x.shape[-1] == 3, f"Unknown x shape to permute {x.shape}"
+ x = x.permute(0, 3, 1, 2)
+
+ # send to device
+ if x.device != self.device:
+ x = x.to(self.device)
+
+ # interpolate to standard input size
+ if auto_resize_640 and (x.shape[-1] != 640 or x.shape[-2] != 640):
+ x = torch.nn.functional.interpolate(x, (640, 640), mode="bilinear")
+
+ if half:
+ x = x.half()
+
+ # adjust painter's latent vector
+ self.G.painter.set_latent_shape(x.shape, True)
+
+ with Timer(store=stores.get("all events", [])):
+ # encode
+ with Timer(store=stores.get("encode", [])):
+ z = self.G.encode(x)
+ if xla:
+ xm.mark_step()
+
+ # predict from masker
+ with Timer(store=stores.get("depth", [])):
+ depth, z_depth = self.G.decoders["d"](z)
+ if xla:
+ xm.mark_step()
+ with Timer(store=stores.get("segmentation", [])):
+ segmentation = self.G.decoders["s"](z, z_depth)
+ if xla:
+ xm.mark_step()
+ with Timer(store=stores.get("mask", [])):
+ cond = self.G.make_m_cond(depth, segmentation, x)
+ mask = self.G.mask(z=z, cond=cond, z_depth=z_depth)
+ if xla:
+ xm.mark_step()
+
+ # apply events
+ if "wildfire" not in ignore_event:
+ with Timer(store=stores.get("wildfire", [])):
+ wildfire = self.compute_fire(x, seg_preds=segmentation)
+ if "smog" not in ignore_event:
+ with Timer(store=stores.get("smog", [])):
+ smog = self.compute_smog(x, d=depth, s=segmentation)
+ if "flood" not in ignore_event:
+ with Timer(store=stores.get("flood", [])):
+ flood = self.compute_flood(
+ x,
+ m=mask,
+ s=segmentation,
+ cloudy=cloudy,
+ bin_value=bin_value,
+ )
+
+ if xla:
+ xm.mark_step()
+
+ if numpy:
+ with Timer(store=stores.get("numpy", [])):
+ # normalize to 0-1
+ flood = normalize(flood).cpu()
+ smog = normalize(smog).cpu()
+ wildfire = normalize(wildfire).cpu()
+
+ # convert to numpy
+ flood = flood.permute(0, 2, 3, 1).numpy()
+ smog = smog.permute(0, 2, 3, 1).numpy()
+ wildfire = wildfire.permute(0, 2, 3, 1).numpy()
+
+ # convert to 0-255 uint8
+ flood = (flood * 255).astype(np.uint8)
+ smog = (smog * 255).astype(np.uint8)
+ wildfire = (wildfire * 255).astype(np.uint8)
+
+ output_data = {"flood": flood, "wildfire": wildfire, "smog": smog}
+ if return_masks:
+ output_data["mask"] = (
+ ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
+ )
+
+ return output_data
+
+ @classmethod
+ def resume_from_path(
+ cls,
+ path,
+ overrides={},
+ setup=True,
+ inference=False,
+ new_exp=False,
+ device=None,
+ verbose=1,
+ ):
+ """
+ Resume and optionally setup a trainer from a specific path,
+ using the latest opts and checkpoint. Requires the path to contain opts.yaml
+ (or an incremented copy), url.txt (or an incremented copy) and checkpoints/
+
+ Args:
+ path (str | pathlib.Path): Trainer to resume
+ overrides (dict, optional): Override loaded opts with those. Defaults to {}.
+ setup (bool, optional): Whether or not to setup the trainer before
+ returning it. Defaults to True.
+ inference (bool, optional): Setup should be done in inference mode or not.
+ Defaults to False.
+ new_exp (bool | None, optional): create a new comet experiment (True),
+ re-use the existing one found in path (False) or disable comet logging
+ (None). Defaults to False.
+ device (torch.device, optional): Device to use
+
+ Returns:
+ climategan.Trainer: Loaded and resumed trainer
+ """
+ p = resolve(path)
+ assert p.exists()
+
+ c = p / "checkpoints"
+ assert c.exists() and c.is_dir()
+
+ opts = get_latest_opts(p)
+ opts = Dict(merge(overrides, opts))
+ opts.train.resume = True
+
+ if new_exp is None:
+ exp = None
+ elif new_exp is True:
+ exp = Experiment(project_name="climategan", **comet_kwargs)
+ exp.log_asset_folder(
+ str(resolve(Path(__file__)).parent),
+ recursive=True,
+ log_file_name=True,
+ )
+ exp.log_parameters(flatten_opts(opts))
+ else:
+ comet_id = get_existing_comet_id(p)
+ exp = ExistingExperiment(previous_experiment=comet_id, **comet_kwargs)
+
+ trainer = cls(opts, comet_exp=exp, device=device, verbose=verbose)
+
+ if setup:
+ trainer.setup(inference=inference)
+ return trainer
+
+ def save(self):
+ save_dir = Path(self.opts.output_path) / Path("checkpoints")
+ save_dir.mkdir(exist_ok=True)
+ save_path = save_dir / "latest_ckpt.pth"
+
+ # Construct relevant state dicts / optims:
+ # Save at least G
+ save_dict = {
+ "epoch": self.logger.epoch,
+ "G": self.G.state_dict(),
+ "g_opt": self.g_opt.state_dict(),
+ "step": self.logger.global_step,
+ }
+
+ if self.D is not None and get_num_params(self.D) > 0:
+ save_dict["D"] = self.D.state_dict()
+ save_dict["d_opt"] = self.d_opt.state_dict()
+
+ if (
+ self.logger.epoch >= self.opts.train.min_save_epoch
+ and self.logger.epoch % self.opts.train.save_n_epochs == 0
+ ):
+ torch.save(save_dict, save_dir / f"epoch_{self.logger.epoch}_ckpt.pth")
+
+ torch.save(save_dict, save_path)
+
+ def resume(self, inference=False):
+ tpu = "xla" in str(self.device)
+ if tpu:
+ print("Resuming on TPU:", self.device)
+
+ m_path = Path(self.opts.load_paths.m)
+ p_path = Path(self.opts.load_paths.p)
+ pm_path = Path(self.opts.load_paths.pm)
+ output_path = Path(self.opts.output_path)
+
+ map_loc = self.device if not tpu else "cpu"
+
+ if "m" in self.opts.tasks and "p" in self.opts.tasks:
+ # ----------------------------------------
+ # ----- Masker and Painter Loading -----
+ # ----------------------------------------
+
+ # want to resume a pm model but no path was provided:
+ # resume a single pm model from output_path
+ if all([str(p) == "none" for p in [m_path, p_path, pm_path]]):
+ checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
+ print("Resuming P+M model from", str(checkpoint_path))
+ checkpoint = torch.load(checkpoint_path, map_location=map_loc)
+
+ # want to resume a pm model with a pm_path provided:
+ # resume a single pm model from load_paths.pm
+ # depending on whether a dir or a file is specified
+ elif str(pm_path) != "none":
+ assert pm_path.exists()
+
+ if pm_path.is_dir():
+ checkpoint_path = pm_path / "checkpoints/latest_ckpt.pth"
+ else:
+ assert pm_path.suffix == ".pth"
+ checkpoint_path = pm_path
+
+ print("Resuming P+M model from", str(checkpoint_path))
+ checkpoint = torch.load(checkpoint_path, map_location=map_loc)
+
+ # want to resume a pm model, pm_path not provided:
+ # m_path and p_path must be provided as dirs or pth files
+ elif m_path != p_path:
+ assert m_path.exists()
+ assert p_path.exists()
+
+ if m_path.is_dir():
+ m_path = m_path / "checkpoints/latest_ckpt.pth"
+
+ if p_path.is_dir():
+ p_path = p_path / "checkpoints/latest_ckpt.pth"
+
+ assert m_path.suffix == ".pth"
+ assert p_path.suffix == ".pth"
+
+ print(f"Resuming P+M model from \n -{p_path} \nand \n -{m_path}")
+ m_checkpoint = torch.load(m_path, map_location=map_loc)
+ p_checkpoint = torch.load(p_path, map_location=map_loc)
+ checkpoint = merge(m_checkpoint, p_checkpoint)
+
+ else:
+ raise ValueError(
+ "Cannot resume a P+M model with provided load_paths:\n{}".format(
+ self.opts.load_paths
+ )
+ )
+
+ else:
+ # ----------------------------------
+ # ----- Single Model Loading -----
+ # ----------------------------------
+
+ # cannot specify both paths
+ if str(m_path) != "none" and str(p_path) != "none":
+ raise ValueError(
+ "Opts tasks are {} but received 2 values for the load_paths".format(
+ self.opts.tasks
+ )
+ )
+
+ # specified m
+ elif str(m_path) != "none":
+ assert m_path.exists()
+ assert "m" in self.opts.tasks
+ model = "M"
+ if m_path.is_dir():
+ m_path = m_path / "checkpoints/latest_ckpt.pth"
+ checkpoint_path = m_path
+
+ # specified p
+ elif str(p_path) != "none":
+ assert p_path.exists()
+ assert "p" in self.opts.tasks
+ model = "P"
+ if p_path.is_dir():
+ p_path = p_path / "checkpoints/latest_ckpt.pth"
+ checkpoint_path = p_path
+
+ # specified neither p nor m: resume from output_path
+ else:
+ model = "P" if "p" in self.opts.tasks else "M"
+ checkpoint_path = output_path / "checkpoints/latest_ckpt.pth"
+
+ print(f"Resuming {model} model from {checkpoint_path}")
+ checkpoint = torch.load(checkpoint_path, map_location=map_loc)
+
+ # On TPUs, the data must be sent to the XLA device as it cannot be mapped
+ # there directly by torch.load
+ if tpu:
+ checkpoint = xm.send_cpu_data_to_device(checkpoint, self.device)
+
+ # -----------------------
+ # ----- Restore G -----
+ # -----------------------
+ if inference:
+ incompatible_keys = self.G.load_state_dict(checkpoint["G"], strict=False)
+ if incompatible_keys.missing_keys:
+ print("WARNING: Missing keys in self.G.load_state_dict, keeping inits")
+ print(incompatible_keys.missing_keys)
+ if incompatible_keys.unexpected_keys:
+ print("WARNING: Ignoring Unexpected keys in self.G.load_state_dict")
+ print(incompatible_keys.unexpected_keys)
+ else:
+ self.G.load_state_dict(checkpoint["G"])
+
+ if inference:
+ # only G is needed to infer
+ print("Done loading checkpoints.")
+ return
+
+ self.g_opt.load_state_dict(checkpoint["g_opt"])
+
+ # ------------------------------
+ # ----- Resume scheduler -----
+ # ------------------------------
+ # https://discuss.pytorch.org/t/a-problem-occured-when-resuming-an-optimizer/28822
+ for _ in range(self.logger.epoch + 1):
+ self.update_learning_rates()
+
+ # -----------------------
+ # ----- Restore D -----
+ # -----------------------
+ if self.D is not None and get_num_params(self.D) > 0:
+ self.D.load_state_dict(checkpoint["D"])
+ self.d_opt.load_state_dict(checkpoint["d_opt"])
+
+ # ---------------------------
+ # ----- Restore logger -----
+ # ---------------------------
+ self.logger.epoch = checkpoint["epoch"]
+ self.logger.global_step = checkpoint["step"]
+ self.exp.log_text(
+ "Resuming from epoch {} & step {}".format(
+ checkpoint["epoch"], checkpoint["step"]
+ )
+ )
+ # Round step to even number for extraGradient
+ if self.logger.global_step % 2 != 0:
+ self.logger.global_step += 1
+
+ def eval_mode(self):
+ """
+ Set trainer's models in eval mode
+ """
+ if self.G is not None:
+ self.G.eval()
+ if self.D is not None:
+ self.D.eval()
+ self.current_mode = "eval"
+
+ def train_mode(self):
+ """
+ Set trainer's models in train mode
+ """
+ if self.G is not None:
+ self.G.train()
+ if self.D is not None:
+ self.D.train()
+
+ self.current_mode = "train"
+
+ def assert_z_matches_x(self, x, z):
+ assert x.shape[0] == (
+ z.shape[0] if not isinstance(z, (list, tuple)) else z[0].shape[0]
+ ), "x-> {}, z->{}".format(
+ x.shape, z.shape if not isinstance(z, (list, tuple)) else z[0].shape
+ )
+
+ def batch_to_device(self, b):
+ """sends the data in b to self.device
+
+ Args:
+ b (dict): the batch dictionary
+
+ Returns:
+ dict: the batch dictionary with its "data" field sent to self.device
+ """
+ for task, tensor in b["data"].items():
+ b["data"][task] = tensor.to(self.device)
+ return b
+
+ def sample_painter_z(self, batch_size):
+ return self.G.sample_painter_z(batch_size, self.device)
+
+ @property
+ def train_loaders(self):
+ """Get a zip of all training loaders
+
+ Returns:
+ generator: zip generator yielding tuples:
+ (batch_rf, batch_rn, batch_sf, batch_sn)
+ """
+ return zip(*list(self.loaders["train"].values()))
+
+ @property
+ def val_loaders(self):
+ """Get a zip of all validation loaders
+
+ Returns:
+ generator: zip generator yielding tuples:
+ (batch_rf, batch_rn, batch_sf, batch_sn)
+ """
+ return zip(*list(self.loaders["val"].values()))
+
+ def compute_latent_shape(self):
+ """Compute the latent shape, i.e. the Encoder's output shape,
+ from a batch.
+
+ Raises:
+ ValueError: If no loader, the latent_shape cannot be inferred
+
+ Returns:
+ tuple: (c, h, w)
+ """
+ x = None
+ for mode in self.all_loaders:
+ for domain in self.all_loaders.loaders[mode]:
+ x = (
+ self.all_loaders[mode][domain]
+ .dataset[0]["data"]["x"]
+ .to(self.device)
+ )
+ break
+ if x is not None:
+ break
+
+ if x is None:
+ raise ValueError("No batch found to compute_latent_shape")
+
+ x = x.unsqueeze(0)
+ z = self.G.encode(x)
+ return z.shape[1:] if not isinstance(z, (list, tuple)) else z[0].shape[1:]
+
+ def g_opt_step(self):
+ """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
+ step every other step
+ """
+ if "extra" in self.opts.gen.opt.optimizer.lower() and (
+ self.logger.global_step % 2 == 0
+ ):
+ self.g_opt.extrapolation()
+ else:
+ self.g_opt.step()
+
+ def d_opt_step(self):
+ """Run an optimizing step ; if using ExtraAdam, there needs to be an extrapolation
+ step every other step
+ """
+ if "extra" in self.opts.dis.opt.optimizer.lower() and (
+ self.logger.global_step % 2 == 0
+ ):
+ self.d_opt.extrapolation()
+ else:
+ self.d_opt.step()
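+        # (sketch) with an Extra* optimizer, even global steps run extrapolation()
+        # and odd steps run step(); this is why resume() and switch_data() bump an
+        # odd global_step to an even one before training continues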
+
+ def update_learning_rates(self):
+ if self.g_scheduler is not None:
+ self.g_scheduler.step()
+ if self.d_scheduler is not None:
+ self.d_scheduler.step()
+
+ def setup(self, inference=False):
+ """Prepare the trainer before it can be used to train the models:
+ * initialize G and D
+ * creates 2 optimizers
+ """
+ self.logger.global_step = 0
+ start_time = time()
+ self.logger.time.start_time = start_time
+ verbose = self.verbose
+
+ if not inference:
+ self.all_loaders = get_all_loaders(self.opts)
+
+ # -----------------------
+ # ----- Generator -----
+ # -----------------------
+ __t = time()
+ print("Creating generator...")
+
+ self.G: OmniGenerator = create_generator(
+ self.opts, device=self.device, no_init=inference, verbose=verbose
+ )
+
+ self.has_painter = get_num_params(self.G.painter) or self.G.load_val_painter()
+
+ if self.has_painter:
+ self.G.painter.set_latent_shape(find_target_size(self.opts, "x"), True)
+
+ print(f"Generator OK in {time() - __t:.1f}s.")
+
+ if inference: # Inference mode: no more than a Generator needed
+ print("Inference mode: no Discriminator, no optimizers")
+ print_num_parameters(self)
+ self.switch_data(to="base")
+ if self.opts.train.resume:
+ self.resume(True)
+ self.eval_mode()
+ print("Trainer is in evaluation mode.")
+ print("Setup done.")
+ self.is_setup = True
+ return
+
+ # ---------------------------
+ # ----- Discriminator -----
+ # ---------------------------
+
+ self.D: OmniDiscriminator = create_discriminator(
+ self.opts, self.device, verbose=verbose
+ )
+ print("Discriminator OK.")
+
+ print_num_parameters(self)
+
+ # --------------------------
+ # ----- Optimization -----
+ # --------------------------
+ # Get different optimizers for each task (different learning rates)
+ self.g_opt, self.g_scheduler, self.lr_names["G"] = get_optimizer(
+ self.G, self.opts.gen.opt, self.opts.tasks
+ )
+
+ if get_num_params(self.D) > 0:
+ self.d_opt, self.d_scheduler, self.lr_names["D"] = get_optimizer(
+ self.D, self.opts.dis.opt, self.opts.tasks, True
+ )
+ else:
+ self.d_opt, self.d_scheduler = None, None
+
+ self.losses = get_losses(self.opts, verbose, device=self.device)
+
+ if "p" in self.opts.tasks and self.opts.gen.p.diff_aug.use:
+ self.diff_transforms = DiffTransforms(self.opts.gen.p.diff_aug)
+
+ if verbose > 0:
+ for mode, mode_dict in self.all_loaders.items():
+ for domain, domain_loader in mode_dict.items():
+ print(
+ "Loader {} {} : {}".format(
+ mode, domain, len(domain_loader.dataset)
+ )
+ )
+
+ # ----------------------------
+ # ----- Display images -----
+ # ----------------------------
+ self.set_display_images()
+
+ # -------------------------------
+ # ----- Log Architectures -----
+ # -------------------------------
+ self.logger.log_architecture()
+
+ # -----------------------------
+ # ----- Set data source -----
+ # -----------------------------
+ if self.kitti_pretrain:
+ self.switch_data(to="kitti")
+ else:
+ self.switch_data(to="base")
+
+ # -------------------------
+ # ----- Setup Done. -----
+ # -------------------------
+ print(" " * 50, end="\r")
+ print("Done creating display images")
+
+ if self.opts.train.resume:
+ print("Resuming Model (inference: False)")
+ self.resume(False)
+ else:
+ print("Not resuming: starting a new model")
+
+ print("Setup done.")
+ self.is_setup = True
+
+ def switch_data(self, to="kitti"):
+ caller = inspect.stack()[1].function
+ print(f"[{caller}] Switching data source to", to)
+ self.data_source = to
+ if to == "kitti":
+ self.display_images = self.kitty_display_images
+ if self.all_loaders is not None:
+ self.loaders = {
+ mode: {"s": self.all_loaders[mode]["kitti"]}
+ for mode in self.all_loaders
+ }
+ else:
+ self.display_images = self.base_display_images
+ if self.all_loaders is not None:
+ self.loaders = {
+ mode: {
+ domain: self.all_loaders[mode][domain]
+ for domain in self.all_loaders[mode]
+ if domain != "kitti"
+ }
+ for mode in self.all_loaders
+ }
+ if (
+ self.logger.global_step % 2 != 0
+ and "extra" in self.opts.dis.opt.optimizer.lower()
+ ):
+ print(
+ "Warning: artificially bumping step to run an extrapolation step first."
+ )
+ self.logger.global_step += 1
+
+ def set_display_images(self, use_all=False):
+ for mode, mode_dict in self.all_loaders.items():
+
+ if self.kitti_pretrain:
+ self.kitty_display_images[mode] = {}
+ self.base_display_images[mode] = {}
+
+ for domain in mode_dict:
+
+ if self.kitti_pretrain and domain == "kitti":
+ target_dict = self.kitty_display_images
+ else:
+ if domain == "kitti":
+ continue
+ target_dict = self.base_display_images
+
+ dataset = self.all_loaders[mode][domain].dataset
+ display_indices = (
+ get_display_indices(self.opts, domain, len(dataset))
+ if not use_all
+ else list(range(len(dataset)))
+ )
+ ldis = len(display_indices)
+ print(
+ f" Creating {ldis} {mode} {domain} display images...",
+ end="\r",
+ flush=True,
+ )
+ target_dict[mode][domain] = [
+ Dict(dataset[i])
+ for i in display_indices
+ if (print(f"({i})", end="\r") is None and i < len(dataset))
+ ]
+ if self.exp is not None:
+ for im_id, d in enumerate(target_dict[mode][domain]):
+ self.exp.log_parameter(
+ "display_image_{}_{}_{}".format(mode, domain, im_id),
+ d["paths"],
+ )
+
+ def train(self):
+ """For each epoch:
+ * train
+ * eval
+ * save
+ """
+ assert self.is_setup
+
+ for self.logger.epoch in range(
+ self.logger.epoch, self.logger.epoch + self.opts.train.epochs
+ ):
+ # backprop painter's disc loss to masker
+ if (
+ self.logger.epoch == self.opts.gen.p.pl4m_epoch
+ and get_num_params(self.G.painter) > 0
+ and "p" in self.opts.tasks
+ and self.opts.gen.m.use_pl4m
+ ):
+ print(
+ "\n\n >>> Enabling pl4m at epoch {}\n\n".format(self.logger.epoch)
+ )
+ self.use_pl4m = True
+
+ self.run_epoch()
+ self.run_evaluation(verbose=1)
+ self.save()
+
+ # end vkitti2 pre-training
+ if self.logger.epoch == self.opts.train.kitti.epochs - 1:
+ self.switch_data(to="base")
+ self.kitti_pretrain = False
+
+ # end pseudo training
+ if self.logger.epoch == self.opts.train.pseudo.epochs - 1:
+ self.pseudo_training_tasks = set()
+
+ def run_epoch(self):
+ """Runs an epoch:
+ * checks trainer is setup
+ * gets a tuple of batches per domain
+ * sends batches to device
+ * updates sequentially G, D
+ """
+ assert self.is_setup
+ self.train_mode()
+ if self.exp is not None:
+ self.exp.log_parameter("epoch", self.logger.epoch)
+ epoch_len = min(len(loader) for loader in self.loaders["train"].values())
+ epoch_desc = "Epoch {}".format(self.logger.epoch)
+ self.logger.time.epoch_start = time()
+
+ for multi_batch_tuple in tqdm(
+ self.train_loaders,
+ desc=epoch_desc,
+ total=epoch_len,
+ mininterval=0.5,
+ unit="batch",
+ ):
+
+ self.logger.time.step_start = time()
+ multi_batch_tuple = shuffle_batch_tuple(multi_batch_tuple)
+
+ # The `[0]` is because the domain is contained in a list
+ multi_domain_batch = {
+ batch["domain"][0]: self.batch_to_device(batch)
+ for batch in multi_batch_tuple
+ }
+ # ------------------------------
+ # ----- Update Generator -----
+ # ------------------------------
+
+ # freeze params of the discriminator
+ if self.d_opt is not None:
+ for param in self.D.parameters():
+ param.requires_grad = False
+
+ self.update_G(multi_domain_batch)
+
+ # ----------------------------------
+ # ----- Update Discriminator -----
+ # ----------------------------------
+
+ # unfreeze params of the discriminator
+ if self.d_opt is not None and not self.kitti_pretrain:
+ for param in self.D.parameters():
+ param.requires_grad = True
+
+ self.update_D(multi_domain_batch)
+
+ # -------------------------
+ # ----- Log Metrics -----
+ # -------------------------
+ self.logger.global_step += 1
+ self.logger.log_step_time(time())
+
+ if not self.kitti_pretrain:
+ self.update_learning_rates()
+
+ self.logger.log_learning_rates()
+ self.logger.log_epoch_time(time())
+
+ def update_G(self, multi_domain_batch, verbose=0):
+ """Perform an update on g from multi_domain_batch which is a dictionary
+ domain => batch
+
+ * automatic mixed precision according to self.opts.train.amp
+ * compute loss for each task
+ * loss.backward()
+ * g_opt_step()
+ * g_opt.step() or .extrapolation() depending on self.logger.global_step
+ * logs losses on comet.ml with self.logger.log_losses(model_to_update="G")
+
+ Args:
+            multi_domain_batch (dict): dictionary of domain batches
+ """
+ zero_grad(self.G)
+ if self.opts.train.amp:
+ with autocast():
+ g_loss = self.get_G_loss(multi_domain_batch, verbose)
+ self.grad_scaler_g.scale(g_loss).backward()
+ self.grad_scaler_g.step(self.g_opt)
+ self.grad_scaler_g.update()
+ else:
+ g_loss = self.get_G_loss(multi_domain_batch, verbose)
+ g_loss.backward()
+ self.g_opt_step()
+
+ self.logger.log_losses(model_to_update="G", mode="train")
+
+ def update_D(self, multi_domain_batch, verbose=0):
+ zero_grad(self.D)
+
+ if self.opts.train.amp:
+ with autocast():
+ d_loss = self.get_D_loss(multi_domain_batch, verbose)
+ self.grad_scaler_d.scale(d_loss).backward()
+ self.grad_scaler_d.step(self.d_opt)
+ self.grad_scaler_d.update()
+ else:
+ d_loss = self.get_D_loss(multi_domain_batch, verbose)
+ d_loss.backward()
+ self.d_opt_step()
+
+ self.logger.losses.disc.total_loss = d_loss.item()
+ self.logger.log_losses(model_to_update="D", mode="train")
+
+ def get_D_loss(self, multi_domain_batch, verbose=0):
+ """Compute the discriminators' losses:
+
+ * for each domain-specific batch:
+ * encode the image
+ * get the conditioning tensor if using spade
+ * source domain is the data's domain, sequentially r|s then f|n
+ * get the target domain accordingly
+ * compute the translated image from the data
+ * compute the source domain discriminator's loss on the data
+ * compute the target domain discriminator's loss on the translated image
+
+ # ? In this setting, each D[decoder][domain] is updated twice towards
+ # real or fake data
+
+ See readme's update d section for details
+
+        Args:
+            multi_domain_batch (dict): dictionary mapping domain names to batches
+                from the trainer's loaders
+
+        Returns:
+            torch.Tensor: scalar loss tensor summing all the discriminators' losses
+ """
+
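+        # Two painter-discriminator setups are handled below (sketch):
+        # * use_local_discriminator: separate "global" (full image) and "local"
+        #   (masked region) discriminators
+        # * otherwise: a single discriminator fed [mask; image] concatenations of
+        #   real and fake pairs in one forward pass (SPADE-style)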
+ disc_loss = {
+ "m": {"Advent": 0},
+ "s": {"Advent": 0},
+ }
+ if self.opts.dis.p.use_local_discriminator:
+ disc_loss["p"] = {"global": 0, "local": 0}
+ else:
+ disc_loss["p"] = {"gan": 0}
+
+ for domain, batch in multi_domain_batch.items():
+ x = batch["data"]["x"]
+
+ # ---------------------
+ # ----- Painter -----
+ # ---------------------
+ if domain == "rf" and self.has_painter:
+ m = batch["data"]["m"]
+ # sample vector
+ with torch.no_grad():
+ # see spade compute_discriminator_loss
+ fake = self.G.paint(m, x)
+ if self.opts.gen.p.diff_aug.use:
+ fake = self.diff_transforms(fake)
+ x = self.diff_transforms(x)
+ fake = fake.detach()
+ fake.requires_grad_()
+
+ if self.opts.dis.p.use_local_discriminator:
+ fake_d_global = self.D["p"]["global"](fake)
+ real_d_global = self.D["p"]["global"](x)
+
+ fake_d_local = self.D["p"]["local"](fake * m)
+ real_d_local = self.D["p"]["local"](x * m)
+
+ global_loss = self.losses["D"]["p"](fake_d_global, False, True)
+ global_loss += self.losses["D"]["p"](real_d_global, True, True)
+
+ local_loss = self.losses["D"]["p"](fake_d_local, False, True)
+ local_loss += self.losses["D"]["p"](real_d_local, True, True)
+
+ disc_loss["p"]["global"] += global_loss
+ disc_loss["p"]["local"] += local_loss
+ else:
+ real_cat = torch.cat([m, x], axis=1)
+ fake_cat = torch.cat([m, fake], axis=1)
+ real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
+ real_fake_d = self.D["p"](real_fake_cat)
+ real_d, fake_d = divide_pred(real_fake_d)
+ disc_loss["p"]["gan"] = self.losses["D"]["p"](fake_d, False, True)
+ disc_loss["p"]["gan"] += self.losses["D"]["p"](real_d, True, True)
+
+ # --------------------
+ # ----- Masker -----
+ # --------------------
+ else:
+ z = self.G.encode(x)
+ s_pred = d_pred = cond = z_depth = None
+
+ if "s" in batch["data"]:
+ if "d" in self.opts.tasks and self.opts.gen.s.use_dada:
+ d_pred, z_depth = self.G.decoders["d"](z)
+
+ step_loss, s_pred = self.masker_s_loss(
+ x, z, d_pred, z_depth, None, domain, for_="D"
+ )
+ step_loss *= self.opts.train.lambdas.advent.adv_main
+ disc_loss["s"]["Advent"] += step_loss
+
+ if "m" in batch["data"]:
+ if "d" in self.opts.tasks:
+ if self.opts.gen.m.use_spade:
+ if d_pred is None:
+ d_pred, z_depth = self.G.decoders["d"](z)
+ cond = self.G.make_m_cond(d_pred, s_pred, x)
+ elif self.opts.gen.m.use_dada:
+ if d_pred is None:
+ d_pred, z_depth = self.G.decoders["d"](z)
+
+ step_loss, _ = self.masker_m_loss(
+ x,
+ z,
+ None,
+ domain,
+ for_="D",
+ cond=cond,
+ z_depth=z_depth,
+ depth_preds=d_pred,
+ )
+ step_loss *= self.opts.train.lambdas.advent.adv_main
+ disc_loss["m"]["Advent"] += step_loss
+
+ self.logger.losses.disc.update(
+ {
+ dom: {
+ k: v.item() if isinstance(v, torch.Tensor) else v
+ for k, v in d.items()
+ }
+ for dom, d in disc_loss.items()
+ }
+ )
+
+ loss = sum(v for d in disc_loss.values() for k, v in d.items())
+ return loss
+
+ def get_G_loss(self, multi_domain_batch, verbose=0):
+ m_loss = p_loss = None
+
+ # For now, always compute "representation loss"
+ g_loss = 0
+
+ if any(t in self.opts.tasks for t in "msd"):
+ m_loss = self.get_masker_loss(multi_domain_batch)
+ self.logger.losses.gen.masker = m_loss.item()
+ g_loss += m_loss
+
+ if "p" in self.opts.tasks and not self.kitti_pretrain:
+ p_loss = self.get_painter_loss(multi_domain_batch)
+ self.logger.losses.gen.painter = p_loss.item()
+ g_loss += p_loss
+
+ assert g_loss != 0 and not isinstance(g_loss, int), "No update in get_G_loss!"
+
+ self.logger.losses.gen.total_loss = g_loss.item()
+
+ return g_loss
+
+    def get_masker_loss(self, multi_domain_batch):
+        """Compute the masker's loss, i.e. the task-specific losses of the
+        representation part of the model (everything but the painter)
+
+        * for each batch in the available domains (the flooded domain "rf" is skipped):
+            * encode the image
+            * compute the "d", "s" and "m" losses for the tasks present in the batch
+
+        Args:
+            multi_domain_batch (dict): dictionary mapping domain names to batches from
+                the trainer's loaders
+
+        Returns:
+            torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
+        """
+ m_loss = 0
+ for domain, batch in multi_domain_batch.items():
+ # We don't care about the flooded domain here
+ if domain == "rf":
+ continue
+
+ x = batch["data"]["x"]
+ z = self.G.encode(x)
+
+ # --------------------------------------
+ # ----- task-specific losses (2) -----
+ # --------------------------------------
+ d_pred = s_pred = z_depth = None
+ for task in ["d", "s", "m"]:
+ if task not in batch["data"]:
+ continue
+
+ target = batch["data"][task]
+
+ if task == "d":
+ loss, d_pred, z_depth = self.masker_d_loss(
+ x, z, target, domain, "G"
+ )
+ m_loss += loss
+ self.logger.losses.gen.task["d"][domain] = loss.item()
+
+ elif task == "s":
+ loss, s_pred = self.masker_s_loss(
+ x, z, d_pred, z_depth, target, domain, "G"
+ )
+ m_loss += loss
+ self.logger.losses.gen.task["s"][domain] = loss.item()
+
+ elif task == "m":
+ cond = None
+ if self.opts.gen.m.use_spade:
+ if not self.opts.gen.m.detach:
+ d_pred = d_pred.clone()
+ s_pred = s_pred.clone()
+ cond = self.G.make_m_cond(d_pred, s_pred, x)
+
+ loss, _ = self.masker_m_loss(
+ x,
+ z,
+ target,
+ domain,
+ "G",
+ cond=cond,
+ z_depth=z_depth,
+ depth_preds=d_pred,
+ )
+ m_loss += loss
+ self.logger.losses.gen.task["m"][domain] = loss.item()
+
+ return m_loss
+
+ def get_painter_loss(self, multi_domain_batch):
+ """Computes the translation loss when flooding/deflooding images
+
+ Args:
+            multi_domain_batch (dict): dictionary mapping domain names to batches from
+ the trainer's loaders
+
+ Returns:
+ torch.Tensor: scalar loss tensor, weighted according to opts.train.lambdas
+ """
+ step_loss = 0
+ # self.g_opt.zero_grad()
+ lambdas = self.opts.train.lambdas
+ batch_domain = "rf"
+ batch = multi_domain_batch[batch_domain]
+
+ x = batch["data"]["x"]
+ # ! different mask: hides water to be reconstructed
+ # ! 1 for water, 0 otherwise
+ m = batch["data"]["m"]
+ fake_flooded = self.G.paint(m, x)
+
+ # ----------------------
+ # ----- VGG Loss -----
+ # ----------------------
+ if lambdas.G.p.vgg != 0:
+ loss = self.losses["G"]["p"]["vgg"](
+ vgg_preprocess(fake_flooded * m), vgg_preprocess(x * m)
+ )
+ loss *= lambdas.G.p.vgg
+ self.logger.losses.gen.p.vgg = loss.item()
+ step_loss += loss
+
+ # ---------------------
+ # ----- TV Loss -----
+ # ---------------------
+ if lambdas.G.p.tv != 0:
+ loss = self.losses["G"]["p"]["tv"](fake_flooded * m)
+ loss *= lambdas.G.p.tv
+ self.logger.losses.gen.p.tv = loss.item()
+ step_loss += loss
+
+ # --------------------------
+ # ----- Context Loss -----
+ # --------------------------
+ if lambdas.G.p.context != 0:
+ loss = self.losses["G"]["p"]["context"](fake_flooded, x, m)
+ loss *= lambdas.G.p.context
+ self.logger.losses.gen.p.context = loss.item()
+ step_loss += loss
+
+ # ---------------------------------
+ # ----- Reconstruction Loss -----
+ # ---------------------------------
+ if lambdas.G.p.reconstruction != 0:
+ loss = self.losses["G"]["p"]["reconstruction"](fake_flooded, x, m)
+ loss *= lambdas.G.p.reconstruction
+ self.logger.losses.gen.p.reconstruction = loss.item()
+ step_loss += loss
+
+ # -------------------------------------
+ # ----- Local & Global GAN Loss -----
+ # -------------------------------------
+ if self.opts.gen.p.diff_aug.use:
+ fake_flooded = self.diff_transforms(fake_flooded)
+ x = self.diff_transforms(x)
+
+ if self.opts.dis.p.use_local_discriminator:
+ fake_d_global = self.D["p"]["global"](fake_flooded)
+ fake_d_local = self.D["p"]["local"](fake_flooded * m)
+
+ real_d_global = self.D["p"]["global"](x)
+
+ # Note: discriminator returns [out_1,...,out_num_D] outputs
+ # Each out_i is a list [feat1, feat2, ..., pred_i]
+
+ self.logger.losses.gen.p.gan = 0
+
+ loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
+ loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
+ loss *= lambdas.G["p"]["gan"]
+
+ self.logger.losses.gen.p.gan = loss.item()
+
+ step_loss += loss
+
+ # -----------------------------------
+ # ----- Feature Matching Loss -----
+ # -----------------------------------
+ # (only on global discriminator)
+ # Order must be real, fake
+ if self.opts.dis.p.get_intermediate_features:
+ loss = self.losses["G"]["p"]["featmatch"](real_d_global, fake_d_global)
+ loss *= lambdas.G["p"]["featmatch"]
+
+ if isinstance(loss, float):
+ self.logger.losses.gen.p.featmatch = loss
+ else:
+ self.logger.losses.gen.p.featmatch = loss.item()
+
+ step_loss += loss
+
+ # -------------------------------------------
+ # ----- Single Discriminator GAN Loss -----
+ # -------------------------------------------
+ else:
+ real_cat = torch.cat([m, x], axis=1)
+ fake_cat = torch.cat([m, fake_flooded], axis=1)
+ real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
+
+ real_fake_d = self.D["p"](real_fake_cat)
+ real_d, fake_d = divide_pred(real_fake_d)
+
+ loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
+ self.logger.losses.gen.p.gan = loss.item()
+ step_loss += loss
+
+ # -----------------------------------
+ # ----- Feature Matching Loss -----
+ # -----------------------------------
+ if self.opts.dis.p.get_intermediate_features and lambdas.G.p.featmatch != 0:
+ loss = self.losses["G"]["p"]["featmatch"](real_d, fake_d)
+ loss *= lambdas.G.p.featmatch
+
+ if isinstance(loss, float):
+ self.logger.losses.gen.p.featmatch = loss
+ else:
+ self.logger.losses.gen.p.featmatch = loss.item()
+
+ step_loss += loss
+
+ return step_loss
+
+ def masker_d_loss(self, x, z, target, domain, for_="G"):
+ assert for_ in {"G", "D"}
+ self.assert_z_matches_x(x, z)
+ assert x.shape[0] == target.shape[0]
+ zero_loss = torch.tensor(0.0, device=self.device)
+ weight = self.opts.train.lambdas.G.d.main
+
+ prediction, z_depth = self.G.decoders["d"](z)
+
+ if self.opts.gen.d.classify.enable:
+ target.squeeze_(1)
+
+ full_loss = self.losses["G"]["tasks"]["d"](prediction, target)
+ full_loss *= weight
+
+ if weight == 0 or (domain == "r" and "d" not in self.pseudo_training_tasks):
+ return zero_loss, prediction, z_depth
+
+ return full_loss, prediction, z_depth
+
+ def masker_s_loss(self, x, z, depth_preds, z_depth, target, domain, for_="G"):
+ assert for_ in {"G", "D"}
+ assert domain in {"r", "s"}
+ self.assert_z_matches_x(x, z)
+        assert target is None or x.shape[0] == target.shape[0]
+ full_loss = torch.tensor(0.0, device=self.device)
+ softmax_preds = None
+ # --------------------------
+ # ----- Segmentation -----
+ # --------------------------
+ pred = None
+ if for_ == "G" or self.opts.gen.s.use_advent:
+ pred = self.G.decoders["s"](z, z_depth)
+
+ # Supervised segmentation loss: crossent for sim domain,
+        # crossent_pseudo for the real domain; the loss is cross-entropy in both cases
+ if for_ == "G":
+ if domain == "s" or "s" in self.pseudo_training_tasks:
+ if domain == "s":
+ logger = self.logger.losses.gen.task["s"]["crossent"]
+ weight = self.opts.train.lambdas.G["s"]["crossent"]
+ else:
+ logger = self.logger.losses.gen.task["s"]["crossent_pseudo"]
+ weight = self.opts.train.lambdas.G["s"]["crossent_pseudo"]
+
+ if weight != 0:
+ # Cross-Entropy loss
+ loss_func = self.losses["G"]["tasks"]["s"]["crossent"]
+ loss = loss_func(pred, target.squeeze(1))
+ loss *= weight
+ full_loss += loss
+ logger[domain] = loss.item()
+
+ if domain == "r":
+ weight = self.opts.train.lambdas.G["s"]["minent"]
+ if self.opts.gen.s.use_minent and weight != 0:
+ softmax_preds = softmax(pred, dim=1)
+ # Entropy minimization loss
+ loss = self.losses["G"]["tasks"]["s"]["minent"](softmax_preds)
+ loss *= weight
+ full_loss += loss
+
+ self.logger.losses.gen.task["s"]["minent"]["r"] = loss.item()
+
+ # Fool ADVENT discriminator
+ if self.opts.gen.s.use_advent:
+ if self.opts.gen.s.use_dada and depth_preds is not None:
+ depth_preds = depth_preds.detach()
+ else:
+ depth_preds = None
+
+ if for_ == "D":
+ domain_label = domain
+ logger = {}
+ loss_func = self.losses["D"]["advent"]
+ pred = pred.detach()
+ weight = self.opts.train.lambdas.advent.adv_main
+ else:
+ domain_label = "s"
+ logger = self.logger.losses.gen.task["s"]["advent"]
+ loss_func = self.losses["G"]["tasks"]["s"]["advent"]
+ weight = self.opts.train.lambdas.G["s"]["advent"]
+
+ if (for_ == "D" or domain == "r") and weight != 0:
+ if softmax_preds is None:
+ softmax_preds = softmax(pred, dim=1)
+ loss = loss_func(
+ softmax_preds,
+ self.domain_labels[domain_label],
+ self.D["s"]["Advent"],
+ depth_preds,
+ )
+ loss *= weight
+ full_loss += loss
+ logger[domain] = loss.item()
+
+ if for_ == "D":
+ # WGAN: clipping or GP
+                if self.opts.dis.s.gan_type in {"GAN", "WGAN_norm"}:
+ pass
+ elif self.opts.dis.s.gan_type == "WGAN":
+ for p in self.D["s"]["Advent"].parameters():
+ p.data.clamp_(
+ self.opts.dis.s.wgan_clamp_lower,
+ self.opts.dis.s.wgan_clamp_upper,
+ )
+ elif self.opts.dis.s.gan_type == "WGAN_gp":
+ prob_need_grad = autograd.Variable(pred, requires_grad=True)
+ d_out = self.D["s"]["Advent"](prob_need_grad)
+ gp = get_WGAN_gradient(prob_need_grad, d_out)
+ gp_loss = gp * self.opts.train.lambdas.advent.WGAN_gp
+ full_loss += gp_loss
+ else:
+ raise NotImplementedError
+
+ return full_loss, pred
+
+ def masker_m_loss(
+ self, x, z, target, domain, for_="G", cond=None, z_depth=None, depth_preds=None
+ ):
+ assert for_ in {"G", "D"}
+ assert domain in {"r", "s"}
+ self.assert_z_matches_x(x, z)
+        assert target is None or x.shape[0] == target.shape[0]
+ full_loss = torch.tensor(0.0, device=self.device)
+
+ pred_logits = self.G.decoders["m"](z, cond=cond, z_depth=z_depth)
+ pred_prob = sigmoid(pred_logits)
+ pred_prob_complementary = 1 - pred_prob
+ prob = torch.cat([pred_prob, pred_prob_complementary], dim=1)
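+        # (sketch) prob stacks [p(mask), 1 - p(mask)] along the channel dim for the
+        # MinEnt / AdvEnt losses below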
+
+ if for_ == "G":
+ # TV loss
+ weight = self.opts.train.lambdas.G.m.tv
+ if weight != 0:
+ loss = self.losses["G"]["tasks"]["m"]["tv"](pred_prob)
+ loss *= weight
+ full_loss += loss
+
+ self.logger.losses.gen.task["m"]["tv"][domain] = loss.item()
+
+ weight = self.opts.train.lambdas.G.m.bce
+ if domain == "s" and weight != 0:
+ # CrossEnt Loss
+ loss = self.losses["G"]["tasks"]["m"]["bce"](pred_logits, target)
+ loss *= weight
+ full_loss += loss
+ self.logger.losses.gen.task["m"]["bce"]["s"] = loss.item()
+
+ if domain == "r":
+
+ weight = self.opts.train.lambdas.G["m"]["gi"]
+ if self.opts.gen.m.use_ground_intersection and weight != 0:
+ # GroundIntersection loss
+ loss = self.losses["G"]["tasks"]["m"]["gi"](pred_prob, target)
+ loss *= weight
+ full_loss += loss
+ self.logger.losses.gen.task["m"]["gi"]["r"] = loss.item()
+
+ weight = self.opts.train.lambdas.G.m.pl4m
+ if self.use_pl4m and weight != 0:
+ # Painter loss
+ pl4m_loss = self.painter_loss_for_masker(x, pred_prob)
+ pl4m_loss *= weight
+ full_loss += pl4m_loss
+ self.logger.losses.gen.task.m.pl4m.r = pl4m_loss.item()
+
+ weight = self.opts.train.lambdas.advent.ent_main
+ if self.opts.gen.m.use_minent and weight != 0:
+ # MinEnt loss
+ loss = self.losses["G"]["tasks"]["m"]["minent"](prob)
+ loss *= weight
+ full_loss += loss
+ self.logger.losses.gen.task["m"]["minent"]["r"] = loss.item()
+
+ if self.opts.gen.m.use_advent:
+ # AdvEnt loss
+ if self.opts.gen.m.use_dada and depth_preds is not None:
+ depth_preds = depth_preds.detach()
+ depth_preds = torch.nn.functional.interpolate(
+ depth_preds, size=x.shape[-2:], mode="nearest"
+ )
+ else:
+ depth_preds = None
+
+ if for_ == "D":
+ domain_label = domain
+ logger = {}
+ loss_func = self.losses["D"]["advent"]
+ prob = prob.detach()
+ weight = self.opts.train.lambdas.advent.adv_main
+ else:
+ domain_label = "s"
+ logger = self.logger.losses.gen.task["m"]["advent"]
+ loss_func = self.losses["G"]["tasks"]["m"]["advent"]
+ weight = self.opts.train.lambdas.advent.adv_main
+
+ if (for_ == "D" or domain == "r") and weight != 0:
+ loss = loss_func(
+ prob.to(self.device),
+ self.domain_labels[domain_label],
+ self.D["m"]["Advent"],
+ depth_preds,
+ )
+ loss *= weight
+ full_loss += loss
+ logger[domain] = loss.item()
+
+ if for_ == "D":
+ # WGAN: clipping or GP
+                    if self.opts.dis.m.gan_type in {"GAN", "WGAN_norm"}:
+ pass
+ elif self.opts.dis.m.gan_type == "WGAN":
+                        for p in self.D["m"]["Advent"].parameters():
+ p.data.clamp_(
+ self.opts.dis.m.wgan_clamp_lower,
+ self.opts.dis.m.wgan_clamp_upper,
+ )
+ elif self.opts.dis.m.gan_type == "WGAN_gp":
+ prob_need_grad = autograd.Variable(prob, requires_grad=True)
+                        d_out = self.D["m"]["Advent"](prob_need_grad)
+ gp = get_WGAN_gradient(prob_need_grad, d_out)
+ gp_loss = self.opts.train.lambdas.advent.WGAN_gp * gp
+ full_loss += gp_loss
+ else:
+ raise NotImplementedError
+
+ return full_loss, prob
+
+ def painter_loss_for_masker(self, x, m):
+ # pl4m loss
+ # painter should not be updated
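+        # (sketch) the masker's soft mask m is painted by the frozen painter and the
+        # painter's GAN loss then flows back into the masker through m only, so the
+        # masker learns to produce masks the painter can flood convincingly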
+ for param in self.G.painter.parameters():
+ param.requires_grad = False
+ # TODO for param in self.D.painter.parameters():
+ # param.requires_grad = False
+
+ fake_flooded = self.G.paint(m, x)
+
+ if self.opts.dis.p.use_local_discriminator:
+ fake_d_global = self.D["p"]["global"](fake_flooded)
+ fake_d_local = self.D["p"]["local"](fake_flooded * m)
+
+ # Note: discriminator returns [out_1,...,out_num_D] outputs
+ # Each out_i is a list [feat1, feat2, ..., pred_i]
+
+ pl4m_loss = self.losses["G"]["p"]["gan"](fake_d_global, True, False)
+ pl4m_loss += self.losses["G"]["p"]["gan"](fake_d_local, True, False)
+ else:
+ real_cat = torch.cat([m, x], axis=1)
+ fake_cat = torch.cat([m, fake_flooded], axis=1)
+ real_fake_cat = torch.cat([real_cat, fake_cat], dim=0)
+
+ real_fake_d = self.D["p"](real_fake_cat)
+ _, fake_d = divide_pred(real_fake_d)
+
+ pl4m_loss = self.losses["G"]["p"]["gan"](fake_d, True, False)
+
+ if "p" in self.opts.tasks:
+ for param in self.G.painter.parameters():
+ param.requires_grad = True
+
+ return pl4m_loss
+
+ @torch.no_grad()
+ def run_evaluation(self, verbose=0):
+ print("******************* Running Evaluation ***********************")
+ start_time = time()
+ self.eval_mode()
+ val_logger = None
+ nb_of_batches = None
+ for i, multi_batch_tuple in enumerate(self.val_loaders):
+            # create a dictionary (domain => batch) from tuple
+ # (batch_domain_0, ..., batch_domain_i)
+ # and send it to self.device
+ nb_of_batches = i + 1
+ multi_domain_batch = {
+ batch["domain"][0]: self.batch_to_device(batch)
+ for batch in multi_batch_tuple
+ }
+ self.get_G_loss(multi_domain_batch, verbose)
+
+ if val_logger is None:
+ val_logger = deepcopy(self.logger.losses.generator)
+ else:
+ val_logger = sum_dict(val_logger, self.logger.losses.generator)
+
+ val_logger = div_dict(val_logger, nb_of_batches)
+ self.logger.losses.generator = val_logger
+ self.logger.log_losses(model_to_update="G", mode="val")
+
+ for d in self.opts.domains:
+ self.logger.log_comet_images("train", d)
+ self.logger.log_comet_images("val", d)
+
+ if "m" in self.opts.tasks and self.has_painter and not self.kitti_pretrain:
+ self.logger.log_comet_combined_images("train", "r")
+ self.logger.log_comet_combined_images("val", "r")
+
+ if self.exp is not None:
+ print()
+
+ if "m" in self.opts.tasks or "s" in self.opts.tasks:
+ self.eval_images("val", "r")
+ self.eval_images("val", "s")
+
+ if "p" in self.opts.tasks and not self.kitti_pretrain:
+ val_fid = compute_val_fid(self)
+ if self.exp is not None:
+ self.exp.log_metric("val_fid", val_fid, step=self.logger.global_step)
+ else:
+ print("Validation FID Score", val_fid)
+
+ self.train_mode()
+ timing = int(time() - start_time)
+ print("****************** Done in {}s *********************".format(timing))
+
+ def eval_images(self, mode, domain):
+ if domain == "s" and self.kitti_pretrain:
+ domain = "kitti"
+ if domain == "rf" or domain not in self.display_images[mode]:
+ return
+
+ metric_funcs = {"accuracy": accuracy, "mIOU": mIOU}
+ metric_avg_scores = {"m": {}}
+ if "s" in self.opts.tasks:
+ metric_avg_scores["s"] = {}
+ if "d" in self.opts.tasks and domain == "s" and self.opts.gen.d.classify.enable:
+ metric_avg_scores["d"] = {}
+
+ for key in metric_funcs:
+ for task in metric_avg_scores:
+ metric_avg_scores[task][key] = []
+
+ for im_set in self.display_images[mode][domain]:
+ x = im_set["data"]["x"].unsqueeze(0).to(self.device)
+ z = self.G.encode(x)
+
+ s_pred = d_pred = z_depth = None
+
+ if "d" in metric_avg_scores:
+ d_pred, z_depth = self.G.decoders["d"](z)
+ d_pred = d_pred.detach().cpu()
+
+ if domain == "s":
+ d = im_set["data"]["d"].unsqueeze(0).detach()
+
+ for metric in metric_funcs:
+ metric_score = metric_funcs[metric](d_pred, d)
+ metric_avg_scores["d"][metric].append(metric_score)
+
+ if "s" in metric_avg_scores:
+ if z_depth is None:
+ if self.opts.gen.s.use_dada and "d" in self.opts.tasks:
+ _, z_depth = self.G.decoders["d"](z)
+ s_pred = self.G.decoders["s"](z, z_depth).detach().cpu()
+ s = im_set["data"]["s"].unsqueeze(0).detach()
+
+ for metric in metric_funcs:
+ metric_score = metric_funcs[metric](s_pred, s)
+ metric_avg_scores["s"][metric].append(metric_score)
+
+ if "m" in self.opts:
+ cond = None
+ if s_pred is not None and d_pred is not None:
+ cond = self.G.make_m_cond(d_pred, s_pred, x)
+ if z_depth is None:
+ if self.opts.gen.m.use_dada and "d" in self.opts.tasks:
+ _, z_depth = self.G.decoders["d"](z)
+
+ pred_mask = (
+ (self.G.mask(z=z, cond=cond, z_depth=z_depth)).detach().cpu()
+ )
+ pred_mask = (pred_mask > 0.5).to(torch.float32)
+ pred_prob = torch.cat([1 - pred_mask, pred_mask], dim=1)
+
+ m = im_set["data"]["m"].unsqueeze(0).detach()
+
+ for metric in metric_funcs:
+ if metric != "mIOU":
+ metric_score = metric_funcs[metric](pred_mask, m)
+ else:
+ metric_score = metric_funcs[metric](pred_prob, m)
+
+ metric_avg_scores["m"][metric].append(metric_score)
+
+ metric_avg_scores = {
+ task: {
+ metric: np.mean(values) if values else float("nan")
+ for metric, values in met_dict.items()
+ }
+ for task, met_dict in metric_avg_scores.items()
+ }
+ metric_avg_scores = {
+ task: {
+ metric: value if not np.isnan(value) else -1
+ for metric, value in met_dict.items()
+ }
+ for task, met_dict in metric_avg_scores.items()
+ }
+ if self.exp is not None:
+ self.exp.log_metrics(
+ flatten_opts(metric_avg_scores),
+ prefix=f"metrics_{mode}_{domain}",
+ step=self.logger.global_step,
+ )
+ else:
+ print(f"metrics_{mode}_{domain}")
+ print(flatten_opts(metric_avg_scores))
+
+ return 0
+
+ def functional_test_mode(self):
+ import atexit
+
+ self.opts.output_path = (
+ Path("~").expanduser() / "climategan" / "functional_tests"
+ )
+ Path(self.opts.output_path).mkdir(parents=True, exist_ok=True)
+ with open(Path(self.opts.output_path) / "is_functional.test", "w") as f:
+ f.write("trainer functional test - delete this dir")
+
+ if self.exp is not None:
+ self.exp.log_parameter("is_functional_test", True)
+ atexit.register(self.del_output_path)
+
+ def del_output_path(self, force=False):
+ import shutil
+
+ if not Path(self.opts.output_path).exists():
+ return
+
+ if (Path(self.opts.output_path) / "is_functional.test").exists() or force:
+ shutil.rmtree(self.opts.output_path)
+
+ def compute_fire(self, x, seg_preds=None, z=None, z_depth=None):
+ """
+ Transforms input tensor given wildfires event
+ Args:
+ x (torch.Tensor): Input tensor
+ seg_preds (torch.Tensor): Semantic segmentation
+ predictions for input tensor
+ z (torch.Tensor): Latent vector of encoded "x".
+ Can be None if seg_preds is given.
+ Returns:
+ torch.Tensor: Wildfire version of input tensor
+ """
+
+ if seg_preds is None:
+ if z is None:
+ z = self.G.encode(x)
+ seg_preds = self.G.decoders["s"](z, z_depth)
+
+ return add_fire(x, seg_preds, self.opts.events.fire)
+
+ def compute_flood(
+ self, x, z=None, z_depth=None, m=None, s=None, cloudy=None, bin_value=-1
+ ):
+ """
+ Applies a flood (mask + paint) to an input image, with optionally
+ pre-computed masker z or mask
+
+ Args:
+ x (torch.Tensor): B x C x H x W -1:1 input image
+ z (torch.Tensor, optional): B x C x H x W Masker latent vector.
+ Defaults to None.
+ m (torch.Tensor, optional): B x 1 x H x W Mask. Defaults to None.
+ bin_value (float, optional): Mask binarization value.
+ Set to -1 to use smooth masks (no binarization)
+
+ Returns:
+ torch.Tensor: B x 3 x H x W -1:1 flooded image
+ """
+
+ if m is None:
+ if z is None:
+ z = self.G.encode(x)
+ if "d" in self.opts.tasks and self.opts.gen.m.use_dada and z_depth is None:
+ _, z_depth = self.G.decoders["d"](z)
+ m = self.G.mask(x=x, z=z, z_depth=z_depth)
+
+ if bin_value >= 0:
+ m = (m > bin_value).to(m.dtype)
+
+ if cloudy:
+ assert s is not None
+ return self.G.paint_cloudy(m, x, s)
+
+ return self.G.paint(m, x)
+
+ def compute_smog(self, x, z=None, d=None, s=None, use_sky_seg=False):
+ # implementation from the paper:
+ # HazeRD: An outdoor scene dataset and benchmark for single image dehazing
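+        # (sketch) standard haze model: smogged = t * irradiance + (1 - t) * airlight,
+        # with transmission t = exp(-beta * d) and d a normalized pseudo-distance
+        # derived from the inverted depth prediction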
+ sky_mask = None
+ if d is None or (use_sky_seg and s is None):
+ if z is None:
+ z = self.G.encode(x)
+ if d is None:
+ d, _ = self.G.decoders["d"](z)
+ if use_sky_seg and s is None:
+ if "s" not in self.opts.tasks:
+ raise ValueError(
+ "Cannot have "
+ + "(use_sky_seg is True and s is None and 's' not in tasks)"
+ )
+ s = self.G.decoders["s"](z)
+ # TODO: s to sky mask
+ # TODO: interpolate to d's size
+
+ params = self.opts.events.smog
+
+ airlight = params.airlight * torch.ones(3)
+ airlight = airlight.view(1, -1, 1, 1).to(self.device)
+
+ irradiance = srgb2lrgb(x)
+
+ beta = torch.tensor([params.beta / params.vr] * 3)
+ beta = beta.view(1, -1, 1, 1).to(self.device)
+
+ d = normalize(d, mini=0.3, maxi=1.0)
+ d = 1.0 / d
+ d = normalize(d, mini=0.1, maxi=1)
+
+ if sky_mask is not None:
+ d[sky_mask] = 1
+
+ d = torch.nn.functional.interpolate(
+ d, size=x.shape[-2:], mode="bilinear", align_corners=True
+ )
+
+ d = d.repeat(1, 3, 1, 1)
+
+ transmission = torch.exp(d * -beta)
+
+ smogged = transmission * irradiance + (1 - transmission) * airlight
+
+ smogged = lrgb2srgb(smogged)
+
+ # add yellow filter
+ alpha = params.alpha / 255
+ yellow_mask = torch.Tensor([params.yellow_color]) / 255
+ yellow_filter = (
+ yellow_mask.unsqueeze(2)
+ .unsqueeze(2)
+ .repeat(1, 1, smogged.shape[-2], smogged.shape[-1])
+ .to(self.device)
+ )
+
+ smogged = smogged * (1 - alpha) + yellow_filter * alpha
+
+ return smogged
diff --git a/climategan/transforms.py b/climategan/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..66b4ab9356e167cc29339300c6996a1f755837aa
--- /dev/null
+++ b/climategan/transforms.py
@@ -0,0 +1,626 @@
+"""Data transforms for the loaders
+"""
+import random
+import traceback
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from skimage.color import rgba2rgb
+from skimage.io import imread
+from torchvision import transforms as trsfs
+from torchvision.transforms.functional import (
+ adjust_brightness,
+ adjust_contrast,
+ adjust_saturation,
+)
+
+from climategan.tutils import normalize
+
+
+def interpolation(task):
+ if task in ["d", "m", "s"]:
+ return {"mode": "nearest"}
+ else:
+ return {"mode": "bilinear", "align_corners": True}
+
+
+class Resize:
+ def __init__(self, target_size, keep_aspect_ratio=False):
+ """
+ Resize transform. Target_size can be an int or a tuple of ints,
+ depending on whether both height and width should have the same
+ final size or not.
+
+ If keep_aspect_ratio is specified then target_size must be an int:
+ the smallest dimension of x will be set to target_size and the largest
+ dimension will be computed to the closest int keeping the original
+ aspect ratio. e.g.
+ >>> x = torch.rand(1, 3, 1200, 1800)
+ >>> m = torch.rand(1, 1, 600, 600)
+ >>> d = {"x": x, "m": m}
+ >>> {k: v.shape for k, v in Resize(640, True)(d).items()}
+ {"x": (1, 3, 640, 960), "m": (1, 1, 640, 960)}
+
+
+
+ Args:
+ target_size (int | tuple(int)): New size for the tensor
+ keep_aspect_ratio (bool, optional): Whether or not to keep aspect ratio
+ when resizing. Requires target_size to be an int. If keeping aspect
+ ratio, smallest dim will be set to target_size. Defaults to False.
+ """
+ if isinstance(target_size, (int, tuple, list)):
+ if not isinstance(target_size, int) and not keep_aspect_ratio:
+ assert len(target_size) == 2
+ self.h, self.w = target_size
+ else:
+ if keep_aspect_ratio:
+ assert isinstance(target_size, int)
+ self.h = self.w = target_size
+
+ self.default_h = int(self.h)
+ self.default_w = int(self.w)
+ self.sizes = {}
+ elif isinstance(target_size, dict):
+ assert (
+ not keep_aspect_ratio
+ ), "dict target_size not compatible with keep_aspect_ratio"
+
+ self.sizes = {
+ k: {"h": v, "w": v} for k, v in target_size.items() if k != "default"
+ }
+ self.default_h = int(target_size["default"])
+ self.default_w = int(target_size["default"])
+
+ self.keep_aspect_ratio = keep_aspect_ratio
+
+ def compute_new_default_size(self, tensor):
+ """
+ compute the new size for a tensor depending on target size
+ and keep_aspect_rato
+
+ Args:
+ tensor (torch.Tensor): 4D tensor N x C x H x W.
+
+ Returns:
+ tuple(int): (new_height, new_width)
+ """
+ if self.keep_aspect_ratio:
+ h, w = tensor.shape[-2:]
+ if h < w:
+ return (self.h, int(self.default_h * w / h))
+ else:
+ return (int(self.default_h * h / w), self.default_w)
+ return (self.default_h, self.default_w)
+
+ def compute_new_size_for_task(self, task):
+ assert (
+ not self.keep_aspect_ratio
+ ), "compute_new_size_for_task is not compatible with keep aspect ratio"
+
+ if task not in self.sizes:
+ return (self.default_h, self.default_w)
+
+ return (self.sizes[task]["h"], self.sizes[task]["w"])
+
+ def __call__(self, data):
+ """
+ Resize a dict of tensors to the "x" key's new_size
+
+ Args:
+ data (dict[str:torch.Tensor]): The data dict to transform
+
+ Returns:
+ dict[str: torch.Tensor]: dict with all tensors resized to the
+ new size of the data["x"] tensor
+ """
+ task = tensor = new_size = None
+ try:
+ if not self.sizes:
+ d = {}
+ new_size = self.compute_new_default_size(
+ data["x"] if "x" in data else list(data.values())[0]
+ )
+ for task, tensor in data.items():
+ d[task] = F.interpolate(
+ tensor, size=new_size, **interpolation(task)
+ )
+ return d
+
+ d = {}
+ for task, tensor in data.items():
+ new_size = self.compute_new_size_for_task(task)
+ d[task] = F.interpolate(tensor, size=new_size, **interpolation(task))
+ return d
+
+ except Exception as e:
+ tb = traceback.format_exc()
+ print("Debug: task, shape, interpolation, h, w, new_size")
+ print(task)
+ print(tensor.shape)
+ print(interpolation(task))
+ print(self.h, self.w)
+ print(new_size)
+ print(tb)
+ raise Exception(e)
+
+
+class RandomCrop:
+ def __init__(self, size, center=False):
+ assert isinstance(size, (int, tuple, list))
+ if not isinstance(size, int):
+ assert len(size) == 2
+ self.h, self.w = size
+ else:
+ self.h = self.w = size
+
+ self.h = int(self.h)
+ self.w = int(self.w)
+ self.center = center
+
+ def __call__(self, data):
+ H, W = (
+ data["x"].size()[-2:] if "x" in data else list(data.values())[0].size()[-2:]
+ )
+
+ if not self.center:
+ top = np.random.randint(0, H - self.h)
+ left = np.random.randint(0, W - self.w)
+ else:
+ top = (H - self.h) // 2
+ left = (W - self.w) // 2
+
+ return {
+ task: tensor[:, :, top : top + self.h, left : left + self.w]
+ for task, tensor in data.items()
+ }
+
+
+class RandomHorizontalFlip:
+ def __init__(self, p=0.5):
+ # self.flip = TF.hflip
+ self.p = p
+
+ def __call__(self, data):
+ if np.random.rand() > self.p:
+ return data
+ return {task: torch.flip(tensor, [3]) for task, tensor in data.items()}
+
+
+class ToTensor:
+ def __init__(self):
+ self.ImagetoTensor = trsfs.ToTensor()
+ self.MaptoTensor = self.ImagetoTensor
+
+ def __call__(self, data):
+ new_data = {}
+ for task, im in data.items():
+ if task in {"x", "a"}:
+ new_data[task] = self.ImagetoTensor(im)
+ elif task in {"m"}:
+ new_data[task] = self.MaptoTensor(im)
+ elif task == "s":
+ new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to(
+ torch.int64
+ )
+ elif task == "d":
+                new_data[task] = im
+
+ return new_data
+
+
+class Normalize:
+ def __init__(self, opts):
+ if opts.data.normalization == "HRNet":
+            self.normImage = trsfs.Normalize(
+                (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
+            )
+ else:
+ self.normImage = trsfs.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ self.normDepth = lambda x: x
+ self.normMask = lambda x: x
+ self.normSeg = lambda x: x
+
+ self.normalize = {
+ "x": self.normImage,
+ "s": self.normSeg,
+ "d": self.normDepth,
+ "m": self.normMask,
+ }
+
+ def __call__(self, data):
+ return {
+ task: self.normalize.get(task, lambda x: x)(tensor.squeeze(0))
+ for task, tensor in data.items()
+ }
+
+
+class RandBrightness: # Input need to be between -1 and 1
+ def __call__(self, data):
+ return {
+ task: rand_brightness(tensor) if task == "x" else tensor
+ for task, tensor in data.items()
+ }
+
+
+class RandSaturation:
+ def __call__(self, data):
+ return {
+ task: rand_saturation(tensor) if task == "x" else tensor
+ for task, tensor in data.items()
+ }
+
+
+class RandContrast:
+ def __call__(self, data):
+ return {
+ task: rand_contrast(tensor) if task == "x" else tensor
+ for task, tensor in data.items()
+ }
+
+
+class BucketizeDepth:
+ def __init__(self, opts, domain):
+ self.domain = domain
+
+ if opts.gen.d.classify.enable and domain in {"s", "kitti"}:
+ self.buckets = torch.linspace(
+ *[
+ opts.gen.d.classify.linspace.min,
+ opts.gen.d.classify.linspace.max,
+ opts.gen.d.classify.linspace.buckets - 1,
+ ]
+ )
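+            # (sketch) e.g. with linspace.min=0, linspace.max=10 and buckets=5,
+            # torch.bucketize maps depth values over these 4 thresholds to integer
+            # class indices in [0, 4]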
+
+ self.transforms = {
+ "d": lambda tensor: torch.bucketize(
+ tensor, self.buckets, out_int32=True, right=True
+ )
+ }
+ else:
+ self.transforms = {}
+
+ def __call__(self, data):
+ return {
+ task: self.transforms.get(task, lambda x: x)(tensor)
+ for task, tensor in data.items()
+ }
+
+
+class PrepareInference:
+ """
+ Transform which:
+ - transforms a str or an array into a tensor
+ - resizes the image to keep the aspect ratio
+ - crops in the center of the resized image
+ - normalize to 0:1
+ - rescale to -1:1
+ """
+
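+    # (sketch) with is_label=True the same resize / center-crop pipeline is applied
+    # to masks, skipping the float conversion and [-1, 1] rescaling,
+    # e.g. PrepareInference(is_label=True)(mask_path)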
+ def __init__(self, target_size=640, half=False, is_label=False, enforce_128=True):
+ if enforce_128:
+ if target_size % 2 ** 7 != 0:
+ raise ValueError(
+ f"Received a target_size of {target_size}, which is not a "
+ + "multiple of 2^7 = 128. Set enforce_128 to False to disable "
+ + "this error."
+ )
+ self.resize = Resize(target_size, keep_aspect_ratio=True)
+ self.crop = RandomCrop((target_size, target_size), center=True)
+ self.half = half
+ self.is_label = is_label
+
+ def process(self, t):
+ if isinstance(t, (str, Path)):
+ t = imread(str(t))
+
+ if isinstance(t, np.ndarray):
+ if t.shape[-1] == 4:
+ t = rgba2rgb(t)
+
+ t = torch.from_numpy(t)
+ if t.ndim == 3:
+ t = t.permute(2, 0, 1)
+
+ if t.ndim == 3:
+ t = t.unsqueeze(0)
+ elif t.ndim == 2:
+ t = t.unsqueeze(0).unsqueeze(0)
+
+ if not self.is_label:
+ t = t.to(torch.float32)
+ t = normalize(t)
+ t = (t - 0.5) * 2
+
+ t = {"m": t} if self.is_label else {"x": t}
+ t = self.resize(t)
+ t = self.crop(t)
+ t = t["m"] if self.is_label else t["x"]
+
+ if self.half and not self.is_label:
+ t = t.half()
+
+ return t
+
+ def __call__(self, x):
+ """
+ normalize, rescale, resize, crop in the center
+
+        x can be: a dict {"task": data}, a list [data, ...] or data alone
+        where data can be a str, a Path, a numpy array or a Tensor
+ """
+ if isinstance(x, dict):
+ return {k: self.process(v) for k, v in x.items()}
+
+ if isinstance(x, list):
+ return [self.process(t) for t in x]
+
+ return self.process(x)
+
+
+class PrepareTest:
+ """
+ Transform which:
+ - transforms a str or an array into a tensor
+ - resizes the image to keep the aspect ratio
+ - crops in the center of the resized image
+ - normalize to 0:1 (optional)
+ - rescale to -1:1 (optional)
+ """
+
+ def __init__(self, target_size=640, half=False):
+ self.resize = Resize(target_size, keep_aspect_ratio=True)
+ self.crop = RandomCrop((target_size, target_size), center=True)
+ self.half = half
+
+ def process(self, t, normalize=False, rescale=False):
+ if isinstance(t, (str, Path)):
+ # t = img_as_float(imread(str(t)))
+ t = imread(str(t))
+ if t.shape[-1] == 4:
+ # t = rgba2rgb(t)
+ t = t[:, :, :3]
+ if np.ndim(t) == 2:
+ t = np.repeat(t[:, :, np.newaxis], 3, axis=2)
+
+ if isinstance(t, np.ndarray):
+ t = torch.from_numpy(t)
+ t = t.permute(2, 0, 1)
+
+ if len(t.shape) == 3:
+ t = t.unsqueeze(0)
+
+ t = t.to(torch.float32)
+        if normalize:
+            # the `normalize` argument shadows climategan.tutils.normalize,
+            # so bring the tensor to [0, 1] inline
+            t = (t - t.min()) / (t.max() - t.min())
+        if rescale:
+            t = (t - 0.5) * 2
+ t = {"x": t}
+ t = self.resize(t)
+ t = self.crop(t)
+ t = t["x"]
+
+ if self.half:
+ return t.to(torch.float16)
+
+ return t
+
+ def __call__(self, x, normalize=False, rescale=False):
+ """
+ Call process()
+
+        x can be: a dict {"task": data}, a list [data, ...] or data alone
+        where data can be a str, a Path, a numpy array or a Tensor
+ """
+ if isinstance(x, dict):
+ return {k: self.process(v, normalize, rescale) for k, v in x.items()}
+
+ if isinstance(x, list):
+ return [self.process(t, normalize, rescale) for t in x]
+
+ return self.process(x, normalize, rescale)
+
+
+def get_transform(transform_item, mode):
+ """Returns the torchivion transform function associated to a
+ transform_item listed in opts.data.transforms ; transform_item is
+ an addict.Dict
+ """
+
+ if transform_item.name == "crop" and not (
+ transform_item.ignore is True or transform_item.ignore == mode
+ ):
+ return RandomCrop(
+ (transform_item.height, transform_item.width),
+ center=transform_item.center == mode,
+ )
+
+ elif transform_item.name == "resize" and not (
+ transform_item.ignore is True or transform_item.ignore == mode
+ ):
+ return Resize(
+ transform_item.new_size, transform_item.get("keep_aspect_ratio", False)
+ )
+
+ elif transform_item.name == "hflip" and not (
+ transform_item.ignore is True or transform_item.ignore == mode
+ ):
+ return RandomHorizontalFlip(p=transform_item.p or 0.5)
+
+ elif transform_item.name == "brightness" and not (
+ transform_item.ignore is True or transform_item.ignore == mode
+ ):
+ return RandBrightness()
+
+ elif transform_item.name == "saturation" and not (
+ transform_item.ignore is True or transform_item.ignore == mode
+ ):
+ return RandSaturation()
+
+ elif transform_item.name == "contrast" and not (
+ transform_item.ignore is True or transform_item.ignore == mode
+ ):
+ return RandContrast()
+
+ elif transform_item.ignore is True or transform_item.ignore == mode:
+ return None
+
+ raise ValueError("Unknown transform_item {}".format(transform_item))
+
+
+def get_transforms(opts, mode, domain):
+ """Get all the transform functions listed in opts.data.transforms
+ using get_transform(transform_item, mode)
+ """
+ transforms = []
+ color_jittering_transforms = ["brightness", "saturation", "contrast"]
+
+ for t in opts.data.transforms:
+ if t.name not in color_jittering_transforms:
+ transforms.append(get_transform(t, mode))
+
+ if "p" not in opts.tasks and mode == "train":
+ for t in opts.data.transforms:
+ if t.name in color_jittering_transforms:
+ transforms.append(get_transform(t, mode))
+
+ transforms += [Normalize(opts), BucketizeDepth(opts, domain)]
+ transforms = [t for t in transforms if t is not None]
+
+ return transforms
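+
+# Example (sketch) of entries in opts.data.transforms consumed by get_transform,
+# assuming the YAML/addict opts layout used by the rest of the project (field
+# names match what get_transform reads; the values below are hypothetical):
+#   - name: resize
+#     ignore: false
+#     new_size: 640
+#     keep_aspect_ratio: true
+#   - name: crop
+#     ignore: false
+#     center: val
+#     height: 600
+#     width: 600
+#   - name: hflip
+#     ignore: false
+#     p: 0.5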
+
+
+# ----- Adapted functions from https://github.com/mit-han-lab/data-efficient-gans -----#
+def rand_brightness(tensor, is_diff_augment=False):
+ if is_diff_augment:
+ assert len(tensor.shape) == 4
+ type_ = tensor.dtype
+ device_ = tensor.device
+ rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
+ return tensor + (rand_tens - 0.5)
+ else:
+ factor = random.uniform(0.5, 1.5)
+ tensor = adjust_brightness(tensor, brightness_factor=factor)
+ # dummy pixels to fool scaling and preserve range
+ tensor[:, :, 0, 0] = 1.0
+ tensor[:, :, -1, -1] = 0.0
+ return tensor
+
+
+def rand_saturation(tensor, is_diff_augment=False):
+ if is_diff_augment:
+ assert len(tensor.shape) == 4
+ type_ = tensor.dtype
+ device_ = tensor.device
+ rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
+ x_mean = tensor.mean(dim=1, keepdim=True)
+ return (tensor - x_mean) * (rand_tens * 2) + x_mean
+ else:
+ factor = random.uniform(0.5, 1.5)
+ tensor = adjust_saturation(tensor, saturation_factor=factor)
+ # dummy pixels to fool scaling and preserve range
+ tensor[:, :, 0, 0] = 1.0
+ tensor[:, :, -1, -1] = 0.0
+ return tensor
+
+
+def rand_contrast(tensor, is_diff_augment=False):
+ if is_diff_augment:
+ assert len(tensor.shape) == 4
+ type_ = tensor.dtype
+ device_ = tensor.device
+ rand_tens = torch.rand(tensor.size(0), 1, 1, 1, dtype=type_, device=device_)
+ x_mean = tensor.mean(dim=[1, 2, 3], keepdim=True)
+ return (tensor - x_mean) * (rand_tens + 0.5) + x_mean
+ else:
+ factor = random.uniform(0.5, 1.5)
+ tensor = adjust_contrast(tensor, contrast_factor=factor)
+ # dummy pixels to fool scaling and preserve range
+ tensor[:, :, 0, 0] = 1.0
+ tensor[:, :, -1, -1] = 0.0
+ return tensor
+
+
+def rand_cutout(tensor, ratio=0.5):
+ assert len(tensor.shape) == 4, "For rand cutout, tensor must be 4D."
+ type_ = tensor.dtype
+ device_ = tensor.device
+ cutout_size = int(tensor.size(-2) * ratio + 0.5), int(tensor.size(-1) * ratio + 0.5)
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(tensor.size(0), dtype=torch.long, device=device_),
+ torch.arange(cutout_size[0], dtype=torch.long, device=device_),
+ torch.arange(cutout_size[1], dtype=torch.long, device=device_),
+ )
+ size_ = [tensor.size(0), 1, 1]
+ offset_x = torch.randint(
+ 0,
+ tensor.size(-2) + (1 - cutout_size[0] % 2),
+ size=size_,
+ device=device_,
+ )
+ offset_y = torch.randint(
+ 0,
+ tensor.size(-1) + (1 - cutout_size[1] % 2),
+ size=size_,
+ device=device_,
+ )
+ grid_x = torch.clamp(
+ grid_x + offset_x - cutout_size[0] // 2, min=0, max=tensor.size(-2) - 1
+ )
+ grid_y = torch.clamp(
+ grid_y + offset_y - cutout_size[1] // 2, min=0, max=tensor.size(-1) - 1
+ )
+ mask = torch.ones(
+ tensor.size(0), tensor.size(2), tensor.size(3), dtype=type_, device=device_
+ )
+ mask[grid_batch, grid_x, grid_y] = 0
+ return tensor * mask.unsqueeze(1)
+
+
+def rand_translation(tensor, ratio=0.125):
+ assert len(tensor.shape) == 4, "For rand translation, tensor must be 4D."
+ device_ = tensor.device
+ shift_x, shift_y = (
+ int(tensor.size(2) * ratio + 0.5),
+ int(tensor.size(3) * ratio + 0.5),
+ )
+ translation_x = torch.randint(
+ -shift_x, shift_x + 1, size=[tensor.size(0), 1, 1], device=device_
+ )
+ translation_y = torch.randint(
+ -shift_y, shift_y + 1, size=[tensor.size(0), 1, 1], device=device_
+ )
+ grid_batch, grid_x, grid_y = torch.meshgrid(
+ torch.arange(tensor.size(0), dtype=torch.long, device=device_),
+ torch.arange(tensor.size(2), dtype=torch.long, device=device_),
+ torch.arange(tensor.size(3), dtype=torch.long, device=device_),
+ )
+ grid_x = torch.clamp(grid_x + translation_x + 1, 0, tensor.size(2) + 1)
+ grid_y = torch.clamp(grid_y + translation_y + 1, 0, tensor.size(3) + 1)
+ x_pad = F.pad(tensor, [1, 1, 1, 1, 0, 0, 0, 0])
+ tensor = (
+ x_pad.permute(0, 2, 3, 1)
+ .contiguous()[grid_batch, grid_x, grid_y]
+ .permute(0, 3, 1, 2)
+ )
+ return tensor
+
+
+class DiffTransforms:
+ def __init__(self, diff_aug_opts):
+ self.do_color_jittering = diff_aug_opts.do_color_jittering
+ self.do_cutout = diff_aug_opts.do_cutout
+ self.do_translation = diff_aug_opts.do_translation
+ self.cutout_ratio = diff_aug_opts.cutout_ratio
+ self.translation_ratio = diff_aug_opts.translation_ratio
+
+ def __call__(self, tensor):
+ if self.do_color_jittering:
+ tensor = rand_brightness(tensor, is_diff_augment=True)
+ tensor = rand_contrast(tensor, is_diff_augment=True)
+ tensor = rand_saturation(tensor, is_diff_augment=True)
+ if self.do_translation:
+ tensor = rand_translation(tensor, ratio=self.translation_ratio)
+ if self.do_cutout:
+ tensor = rand_cutout(tensor, ratio=self.cutout_ratio)
+ return tensor
diff --git a/climategan/tutils.py b/climategan/tutils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cdaee9d081bb3010d21570b0d38fc7814595937
--- /dev/null
+++ b/climategan/tutils.py
@@ -0,0 +1,721 @@
+"""Tensor-utils
+"""
+import io
+import math
+from contextlib import redirect_stdout
+from pathlib import Path
+
+# from copy import copy
+from threading import Thread
+
+import numpy as np
+import torch
+import torch.nn as nn
+from skimage import io as skio
+from torch import autograd
+from torch.autograd import Variable
+from torch.nn import init
+
+from climategan.utils import all_texts_to_array
+
+
+def transforms_string(ts):
+ return " -> ".join([t.__class__.__name__ for t in ts.transforms])
+
+
+def init_weights(net, init_type="normal", init_gain=0.02, verbose=0, caller=""):
+ """Initialize network weights.
+ Parameters:
+ net (network) -- network to be initialized
+ init_type (str) -- the name of an initialization method:
+ normal | xavier | kaiming | orthogonal
+ init_gain (float) -- scaling factor for normal, xavier and orthogonal.
+
+ We use 'normal' in the original pix2pix and CycleGAN paper.
+ But xavier and kaiming might work better for some applications.
+ Feel free to try yourself.
+ """
+
+ if not init_type:
+ print(
+ "init_weights({}): init_type is {}, defaulting to normal".format(
+ caller + " " + net.__class__.__name__, init_type
+ )
+ )
+ init_type = "normal"
+ if not init_gain:
+ print(
+ "init_weights({}): init_gain is {}, defaulting to normal".format(
+ caller + " " + net.__class__.__name__, init_type
+ )
+ )
+ init_gain = 0.02
+
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find("BatchNorm2d") != -1:
+ if hasattr(m, "weight") and m.weight is not None:
+ init.normal_(m.weight.data, 1.0, init_gain)
+ if hasattr(m, "bias") and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, "weight") and (
+ classname.find("Conv") != -1 or classname.find("Linear") != -1
+ ):
+ if init_type == "normal":
+ init.normal_(m.weight.data, 0.0, init_gain)
+ elif init_type == "xavier":
+ init.xavier_normal_(m.weight.data, gain=init_gain)
+ elif init_type == "xavier_uniform":
+ init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == "kaiming":
+ init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
+ elif init_type == "orthogonal":
+ init.orthogonal_(m.weight.data, gain=init_gain)
+ elif init_type == "none": # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(
+ "initialization method [%s] is not implemented" % init_type
+ )
+ if hasattr(m, "bias") and m.bias is not None:
+ init.constant_(m.bias.data, 0.0)
+
+ if verbose > 0:
+ print("initialize %s with %s" % (net.__class__.__name__, init_type))
+ net.apply(init_func)
+
+
+def domains_to_class_tensor(domains, one_hot=False):
+ """Converts a list of strings to a 1D Tensor representing the domains
+
+    domains_to_class_tensor(["s", "r"])
+    >>> torch.Tensor([1, 0])
+
+    Args:
+        domains (list(str)): each element of the list should be in {r, s}
+        one_hot (bool, optional): whether or not to 1-hot encode class labels.
+            Defaults to False.
+    Raises:
+        ValueError: One of the domains listed is not in {r, s}
+
+ Returns:
+ torch.Tensor: 1D tensor mapping a domain to an int (not 1-hot) or 1-hot
+ domain labels in a 2D tensor
+ """
+
+ mapping = {"r": 0, "s": 1}
+
+ if not all(domain in mapping for domain in domains):
+ raise ValueError(
+ "Unknown domains {} should be in {}".format(domains, list(mapping.keys()))
+ )
+
+ target = torch.tensor([mapping[domain] for domain in domains])
+
+ if one_hot:
+ one_hot_target = torch.FloatTensor(len(target), 2) # 2 domains
+ one_hot_target.zero_()
+ one_hot_target.scatter_(1, target.unsqueeze(1), 1)
+ # https://discuss.pytorch.org/t/convert-int-into-one-hot-format/507
+ target = one_hot_target
+ return target
+
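+# Example of the one-hot branch above (values follow from the mapping {"r": 0, "s": 1}):
+#   domains_to_class_tensor(["r", "s"], one_hot=True)
+#   >>> tensor([[1., 0.],
+#               [0., 1.]])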
+
+def fake_domains_to_class_tensor(domains, one_hot=False):
+ """Converts a list of strings to a 1D Tensor representing the fake domains
+ (real or sim only)
+
+    fake_domains_to_class_tensor(["s", "r"], False)
+    >>> torch.Tensor([0, 1])
+
+    Args:
+        domains (list(str)): each element of the list should be in {r, s}
+        one_hot (bool, optional): whether or not to 1-hot encode class labels.
+            Defaults to False.
+    Raises:
+        ValueError: One of the domains listed is not in {r, s}
+
+    Returns:
+        torch.Tensor: 1D tensor mapping a domain to an int (not 1-hot) or
+            a 2D tensor filled with 0.5 to fool the classifier (equiprobability
+            for each domain).
+ """
+ if one_hot:
+ target = torch.FloatTensor(len(domains), 2)
+ target.fill_(0.5)
+
+ else:
+ mapping = {"r": 1, "s": 0}
+
+ if not all(domain in mapping for domain in domains):
+ raise ValueError(
+ "Unknown domains {} should be in {}".format(
+ domains, list(mapping.keys())
+ )
+ )
+
+ target = torch.tensor([mapping[domain] for domain in domains])
+ return target
+
+
+def show_tanh_tensor(tensor):
+ import skimage
+
+ if isinstance(tensor, torch.Tensor):
+ image = tensor.permute(1, 2, 0).detach().numpy()
+ else:
+ image = tensor
+ if image.shape[-1] != 3:
+ image = image.transpose(1, 2, 0)
+
+ if image.min() < 0 and image.min() > -1:
+ image = image / 2 + 0.5
+ elif image.min() < -1:
+ raise ValueError("can't handle this data")
+
+ skimage.io.imshow(image)
+
+
+def normalize_tensor(t):
+ """
+ Brings any tensor to the [0; 1] range.
+
+ Args:
+ t (torch.Tensor): input to normalize
+
+ Returns:
+ torch.Tensor: t projected to [0; 1]
+ """
+ t = t - torch.min(t)
+ t = t / torch.max(t)
+ return t
+
+
+def get_normalized_depth_t(tensor, domain, normalize=False, log=True):
+ assert not (normalize and log)
+ if domain == "r":
+ # megadepth depth
+ tensor = tensor.unsqueeze(0)
+ tensor = tensor - torch.min(tensor)
+ tensor = torch.true_divide(tensor, torch.max(tensor))
+
+ elif domain == "s":
+ # from 3-channel depth encoding from Unity simulator to 1-channel [0-1] values
+ tensor = decode_unity_depth_t(tensor, log=log, normalize=normalize)
+
+ elif domain == "kitti":
+ tensor = tensor / 100
+ if not log:
+ tensor = 1 / tensor
+ if normalize:
+ tensor = tensor - tensor.min()
+ tensor = tensor / tensor.max()
+ else:
+ tensor = torch.log(tensor)
+
+ tensor = tensor.unsqueeze(0)
+
+ return tensor
+
+
+def decode_bucketed_depth(tensor, opts):
+ # tensor is size 1 x C x H x W
+ assert tensor.shape[0] == 1
+ idx = torch.argmax(tensor.squeeze(0), dim=0) # channels become dim 0 with squeeze
+ linspace_args = (
+ opts.gen.d.classify.linspace.min,
+ opts.gen.d.classify.linspace.max,
+ opts.gen.d.classify.linspace.buckets,
+ )
+ indexer = torch.linspace(*linspace_args)
+ log_depth = indexer[idx.long()].to(torch.float32) # H x W
+ depth = torch.exp(log_depth)
+ return depth.unsqueeze(0).unsqueeze(0).to(tensor.device)
+
+
+def decode_unity_depth_t(unity_depth, log=True, normalize=False, numpy=False, far=1000):
+ """Transforms the 3-channel encoded depth map from our Unity simulator
+    to a 1-channel depth map containing metric depth values.
+ The depth is encoded in the following way:
+ - The information from the simulator is (1 - LinearDepth (in [0,1])).
+ far corresponds to the furthest distance to the camera included in the
+ depth map.
+ LinearDepth * far gives the real metric distance to the camera.
+ - depth is first divided in 31 slices encoded in R channel with values ranging
+ from 0 to 247
+ - each slice is divided again in 31 slices, whose value is encoded in G channel
+ - each of the G slices is divided into 256 slices, encoded in B channel
+
+    In total, depth is discretized into N = 31*31*256 - 1 possible values,
+    giving a resolution of far/N meters.
+
+    Note that what we encode here is 1 - LinearDepth, so that the furthest point is
+    [0,0,0] (that is, the sky) and the closest point is [255,255,255].
+
+    The metric distance associated with a pixel whose depth is (R,G,B) is:
+    d = (far/N) * [((255 - R)//8)*256*31 + ((255 - G)//8)*256 + (255 - B)]
+
+ * torch.Tensor in [0, 1] as torch.float32 if numpy == False
+
+ * else numpy.array in [0, 255] as np.uint8
+
+ Args:
+ unity_depth (torch.Tensor): one depth map obtained from our simulator
+ numpy (bool, optional): Whether to return a float tensor or an int array.
+ Defaults to False.
+ far: far parameter of the camera in Unity simulator.
+
+ Returns:
+ [torch.Tensor or numpy.array]: decoded depth
+ """
+ R = unity_depth[:, :, 0]
+ G = unity_depth[:, :, 1]
+ B = unity_depth[:, :, 2]
+
+ R = ((247 - R) / 8).type(torch.IntTensor)
+ G = ((247 - G) / 8).type(torch.IntTensor)
+ B = (255 - B).type(torch.IntTensor)
+ depth = ((R * 256 * 31 + G * 256 + B).type(torch.FloatTensor)) / (256 * 31 * 31 - 1)
+ depth = depth * far
+ if not log:
+ depth = 1 / depth
+ depth = depth.unsqueeze(0) # (depth * far).unsqueeze(0)
+
+ if log:
+ depth = torch.log(depth)
+ if normalize:
+ depth = depth - torch.min(depth)
+ depth /= torch.max(depth)
+ if numpy:
+ depth = depth.data.cpu().numpy()
+ return depth.astype(np.uint8).squeeze()
+ return depth
+
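+# Shape summary for decode_unity_depth_t (derived from the code above): the input
+# is an H x W x 3 tensor from the simulator; the output is a 1 x H x W float
+# tensor (log-depth when log=True), or an H x W np.uint8 array when numpy=True.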
+
+def to_inv_depth(log_depth, numpy=False):
+ """Convert log depth tensor to inverse depth image for display
+
+ Args:
+ depth (Tensor): log depth float tensor
+ """
+ depth = torch.exp(log_depth)
+ # visualize prediction using inverse depth, so that we don't need sky
+ # segmentation (if you want to use RGB map for visualization,
+ # you have to run semantic segmentation to mask the sky first
+ # since the depth of sky is random from CNN)
+ inv_depth = 1 / depth
+ inv_depth /= torch.max(inv_depth)
+ if numpy:
+ inv_depth = inv_depth.data.cpu().numpy()
+ # you might also use percentile for better visualization
+
+ return inv_depth
+
+
+def shuffle_batch_tuple(mbt):
+ """shuffle the order of domains in the batch
+
+ Args:
+ mbt (tuple): multi-batch tuple
+
+ Returns:
+ list: randomized list of domain-specific batches
+ """
+ assert isinstance(mbt, (tuple, list))
+ assert len(mbt) > 0
+ perm = np.random.permutation(len(mbt))
+ return [mbt[i] for i in perm]
+
+
+def slice_batch(batch, slice_size):
+ assert slice_size > 0
+ for k, v in batch.items():
+ if isinstance(v, dict):
+ for task, d in v.items():
+ batch[k][task] = d[:slice_size]
+ else:
+ batch[k] = v[:slice_size]
+ return batch
+
+
+def save_tanh_tensor(image, path):
+ """Save an image which can be numpy or tensor, 2 or 3 dims (no batch)
+ to path.
+
+ Args:
+ image (np.array or torch.Tensor): image to save
+ path (pathlib.Path or str): where to save the image
+ """
+ path = Path(path)
+ if isinstance(image, torch.Tensor):
+ image = image.detach().cpu().numpy()
+ if image.shape[-1] != 3 and image.shape[0] == 3:
+ image = np.transpose(image, (1, 2, 0))
+ if image.min() < 0 and image.min() > -1:
+ image = image / 2 + 0.5
+ elif image.min() < -1:
+ image -= image.min()
+ image /= image.max()
+ # print("Warning: scaling image data in save_tanh_tensor")
+
+ skio.imsave(path, (image * 255).astype(np.uint8))
+
+
+def save_batch(multi_domain_batch, root="./", step=0, num_threads=5):
+ root = Path(root)
+ root.mkdir(parents=True, exist_ok=True)
+ images_to_save = {"paths": [], "images": []}
+ for domain, batch in multi_domain_batch.items():
+ y = batch["data"].get("y")
+ x = batch["data"]["x"]
+ if y is not None:
+ paths = batch["paths"]["x"]
+ imtensor = torch.cat([x, y], dim=-1)
+ for i, im in enumerate(imtensor):
+ imid = Path(paths[i]).stem[:10]
+ images_to_save["paths"] += [
+ root / "im_{}_{}_{}.png".format(step, domain, imid)
+ ]
+ images_to_save["images"].append(im)
+ if num_threads > 0:
+ threaded_write(images_to_save["images"], images_to_save["paths"], num_threads)
+ else:
+ for im, path in zip(images_to_save["images"], images_to_save["paths"]):
+ save_tanh_tensor(im, path)
+
+
+def threaded_write(images, paths, num_threads=5):
+ t_im = []
+ t_p = []
+ for im, p in zip(images, paths):
+ t_im.append(im)
+ t_p.append(p)
+ if len(t_im) == num_threads:
+ ts = [
+ Thread(target=save_tanh_tensor, args=(_i, _p))
+ for _i, _p in zip(t_im, t_p)
+ ]
+ list(map(lambda t: t.start(), ts))
+ list(map(lambda t: t.join(), ts))
+ t_im = []
+ t_p = []
+ if t_im:
+ ts = [
+ Thread(target=save_tanh_tensor, args=(_i, _p)) for _i, _p in zip(t_im, t_p)
+ ]
+ list(map(lambda t: t.start(), ts))
+ list(map(lambda t: t.join(), ts))
+
+
+def get_num_params(model):
+ total_params = sum(p.numel() for p in model.parameters())
+ return total_params
+
+
+def vgg_preprocess(batch):
+ """Preprocess batch to use VGG model"""
+ tensortype = type(batch.data)
+ (r, g, b) = torch.chunk(batch, 3, dim=1)
+ batch = torch.cat((b, g, r), dim=1) # convert RGB to BGR
+ batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255]
+ mean = tensortype(batch.data.size()).cuda()
+ mean[:, 0, :, :] = 103.939
+ mean[:, 1, :, :] = 116.779
+ mean[:, 2, :, :] = 123.680
+ batch = batch.sub(Variable(mean)) # subtract mean
+ return batch
+
+
+def zero_grad(model: nn.Module):
+ """
+    Sets gradients to None. More efficient than model.zero_grad()
+ or opt.zero_grad() according to https://www.youtube.com/watch?v=9mS1fIYj1So
+
+ Args:
+ model (nn.Module): model to zero out
+ """
+ for p in model.parameters():
+ p.grad = None
+
+
+# Take the prediction of fake and real images from the combined batch
+def divide_pred(disc_output):
+ """
+ Divide a multiscale discriminator's output into 2 sets of tensors,
+ expecting the input to the discriminator to be a concatenation
+ on the batch axis of real and fake (or fake and real) images,
+ effectively doubling the batch size for better batchnorm statistics
+
+ Args:
+ disc_output (list | torch.Tensor): Discriminator output to split
+
+ Returns:
+        list | torch.Tensor: pair of split outputs
+ """
+ # https://github.com/NVlabs/SPADE/blob/master/models/pix2pix_model.py
+ # the prediction contains the intermediate outputs of multiscale GAN,
+ # so it's usually a list
+ if type(disc_output) == list:
+ half1 = []
+ half2 = []
+ for p in disc_output:
+ half1.append([tensor[: tensor.size(0) // 2] for tensor in p])
+ half2.append([tensor[tensor.size(0) // 2 :] for tensor in p])
+ else:
+ half1 = disc_output[: disc_output.size(0) // 2]
+ half2 = disc_output[disc_output.size(0) // 2 :]
+
+ return half1, half2
+
+
+def is_tpu_available():
+ _torch_tpu_available = False
+ try:
+ import torch_xla.core.xla_model as xm # type: ignore
+
+ if "xla" in str(xm.xla_device()):
+ _torch_tpu_available = True
+ else:
+ _torch_tpu_available = False
+ except ImportError:
+ _torch_tpu_available = False
+
+ return _torch_tpu_available
+
+
+def get_WGAN_gradient(input, output):
+ # github code reference:
+ # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py
+ # Calculate the gradient that WGAN-gp needs
+ grads = autograd.grad(
+ outputs=output,
+ inputs=input,
+ grad_outputs=torch.ones(output.size()).cuda(),
+ create_graph=True,
+ retain_graph=True,
+ only_inputs=True,
+ )[0]
+ grads = grads.view(grads.size(0), -1)
+ gp = ((grads.norm(2, dim=1) - 1) ** 2).mean()
+ return gp
+
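+# Typical WGAN-GP usage sketch (editor's note; `real`, `fake`, `D` and `lambda_gp`
+# are placeholders, not names defined in this file): interpolate real and fake
+# samples, make the interpolation require gradients, run it through the
+# discriminator and add the penalty to the discriminator loss.
+#
+#   alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
+#   interp = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
+#   gp = get_WGAN_gradient(interp, D(interp))
+#   d_loss = d_loss + lambda_gp * gp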
+
+def print_num_parameters(trainer, force=False):
+ if trainer.verbose == 0 and not force:
+ return
+ print("-" * 35)
+ if trainer.G.encoder is not None:
+ print(
+ "{:21}:".format("num params encoder"),
+ f"{get_num_params(trainer.G.encoder):12,}",
+ )
+ for d in trainer.G.decoders.keys():
+ print(
+ "{:21}:".format(f"num params decoder {d}"),
+ f"{get_num_params(trainer.G.decoders[d]):12,}",
+ )
+
+ print(
+ "{:21}:".format("num params painter"),
+ f"{get_num_params(trainer.G.painter):12,}",
+ )
+
+ if trainer.D is not None:
+ for d in trainer.D.keys():
+ print(
+ "{:21}:".format(f"num params discrim {d}"),
+ f"{get_num_params(trainer.D[d]):12,}",
+ )
+
+ print("-" * 35)
+
+
+def srgb2lrgb(x):
+ x = normalize(x)
+ im = ((x + 0.055) / 1.055) ** (2.4)
+ im[x <= 0.04045] = x[x <= 0.04045] / 12.92
+ return im
+
+
+def lrgb2srgb(ims):
+ if len(ims.shape) == 3:
+ ims = [ims]
+ stack = False
+ else:
+ ims = list(ims)
+ stack = True
+
+ outs = []
+ for im in ims:
+
+ out = torch.zeros_like(im)
+ for k in range(3):
+ temp = im[k, :, :]
+
+ out[k, :, :] = 12.92 * temp * (temp <= 0.0031308) + (
+ 1.055 * torch.pow(temp, (1 / 2.4)) - 0.055
+ ) * (temp > 0.0031308)
+ outs.append(out)
+
+ if stack:
+ return torch.stack(outs)
+
+ return outs[0]
+
+
+def normalize(t, mini=0, maxi=1):
+ if len(t.shape) == 3:
+ return mini + (maxi - mini) * (t - t.min()) / (t.max() - t.min())
+
+ batch_size = t.shape[0]
+ min_t = t.reshape(batch_size, -1).min(1)[0].reshape(batch_size, 1, 1, 1)
+ t = t - min_t
+ max_t = t.reshape(batch_size, -1).max(1)[0].reshape(batch_size, 1, 1, 1)
+ t = t / max_t
+ return mini + (maxi - mini) * t
+
+
+def retrieve_sky_mask(seg):
+ """
+ get the binary mask for the sky given a segmentation tensor
+ of logits (N x C x H x W) or labels (N x H x W)
+
+ Args:
+ seg (torch.Tensor): Segmentation map
+
+ Returns:
+ torch.Tensor: Sky mask
+ """
+ if len(seg.shape) == 4: # Predictions
+ seg_ind = torch.argmax(seg, dim=1)
+ else:
+ seg_ind = seg
+
+ sky_mask = seg_ind == 9
+ return sky_mask
+
+
+def all_texts_to_tensors(texts, width=640, height=40):
+ """
+ Creates a list of tensors with texts from PIL images
+
+ Args:
+ texts (list(str)): texts to write
+ width (int, optional): width of individual texts. Defaults to 640.
+ height (int, optional): height of individual texts. Defaults to 40.
+
+ Returns:
+ list(torch.Tensor): len(texts) tensors 3 x height x width
+ """
+ arrays = all_texts_to_array(texts, width, height)
+ arrays = [array.transpose(2, 0, 1) for array in arrays]
+ return [torch.tensor(array) for array in arrays]
+
+
+def write_architecture(trainer):
+ stem = "archi"
+ out = Path(trainer.opts.output_path)
+
+ # encoder
+ with open(out / f"{stem}_encoder.txt", "w") as f:
+ f.write(str(trainer.G.encoder))
+
+ # decoders
+ for k, v in trainer.G.decoders.items():
+ with open(out / f"{stem}_decoder_{k}.txt", "w") as f:
+ f.write(str(v))
+
+ # painter
+ if get_num_params(trainer.G.painter) > 0:
+ with open(out / f"{stem}_painter.txt", "w") as f:
+ f.write(str(trainer.G.painter))
+
+ # discriminators
+ if get_num_params(trainer.D) > 0:
+ for k, v in trainer.D.items():
+ with open(out / f"{stem}_discriminator_{k}.txt", "w") as f:
+ f.write(str(v))
+
+ with io.StringIO() as buf, redirect_stdout(buf):
+ print_num_parameters(trainer)
+ output = buf.getvalue()
+ with open(out / "archi_num_params.txt", "w") as f:
+ f.write(output)
+
+
+def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
+ delta = (res[0] / shape[0], res[1] / shape[1])
+ d = (shape[0] // res[0], shape[1] // res[1])
+
+ grid = (
+ torch.stack(
+ torch.meshgrid(
+ torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])
+ ),
+ dim=-1,
+ )
+ % 1
+ )
+ angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
+
+ tile_grads = (
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
+ .repeat_interleave(d[0], 0)
+ .repeat_interleave(d[1], 1)
+ )
+ dot = lambda grad, shift: ( # noqa: E731
+ torch.stack(
+ (
+ grid[: shape[0], : shape[1], 0] + shift[0],
+ grid[: shape[0], : shape[1], 1] + shift[1],
+ ),
+ dim=-1,
+ )
+ * grad[: shape[0], : shape[1]]
+ ).sum(dim=-1)
+
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
+ t = fade(grid[: shape[0], : shape[1]])
+ return math.sqrt(2) * torch.lerp(
+ torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1]
+ )
+
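+# Usage note for rand_perlin_2d (follows from the code above): `shape` should be
+# divisible by `res`, e.g. rand_perlin_2d((640, 640), (8, 8)) returns a 640 x 640
+# tensor of Perlin noise, roughly in [-1, 1].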
+
+def mix_noise(x, mask, res=(8, 3), weight=0.1):
+ noise = rand_perlin_2d(x.shape[-2:], res).unsqueeze(0).unsqueeze(0).to(x.device)
+ noise = noise - noise.min()
+ mask = mask.repeat(1, 3, 1, 1).to(x.device).to(torch.float16)
+ y = mask * (weight * noise + (1 - weight) * x) + (1 - mask) * x
+ return y
+
+
+def tensor_ims_to_np_uint8s(ims):
+ """
+    transform a CHW or NCHW tensor into a list of np.uint8 [0, 255]
+    HWC image arrays
+
+    Args:
+        ims (torch.Tensor | list): image tensor(s) in [-1, 1] to convert
+ """
+ if not isinstance(ims, list):
+ assert isinstance(ims, torch.Tensor)
+ if ims.ndim == 3:
+ ims = [ims]
+
+ nps = []
+ for t in ims:
+ if t.shape[0] == 3:
+ t = t.permute(1, 2, 0)
+ else:
+ assert t.shape[-1] == 3
+
+ n = t.cpu().numpy()
+ n = (n + 1) / 2 * 255
+ nps.append(n.astype(np.uint8))
+
+ return nps[0] if len(nps) == 1 else nps
diff --git a/climategan/utils.py b/climategan/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0475688a757bd11e610b570f702ed8d24a63daf5
--- /dev/null
+++ b/climategan/utils.py
@@ -0,0 +1,1063 @@
+"""All non-tensor utils
+"""
+import contextlib
+import datetime
+import json
+import os
+import re
+import shutil
+import subprocess
+import time
+import traceback
+from os.path import expandvars
+from pathlib import Path
+from typing import Any, List, Optional, Union
+from uuid import uuid4
+
+import numpy as np
+import torch
+import yaml
+from addict import Dict
+from comet_ml import Experiment
+
+comet_kwargs = {
+ "auto_metric_logging": False,
+ "parse_args": True,
+ "log_env_gpu": True,
+ "log_env_cpu": True,
+ "display_summary_level": 0,
+}
+
+IMG_EXTENSIONS = set(
+ [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP"]
+)
+
+
+def resolve(path):
+ """
+ fully resolve a path:
+ resolve env vars ($HOME etc.) -> expand user (~) -> make absolute
+
+ Returns:
+ pathlib.Path: resolved absolute path
+ """
+ return Path(expandvars(str(path))).expanduser().resolve()
+
+
+def copy_run_files(opts: Dict) -> None:
+ """
+    Copy the opts's sbatch_file and exp_file to output_path
+
+ Args:
+ opts (addict.Dict): options
+ """
+ if opts.sbatch_file:
+ p = resolve(opts.sbatch_file)
+ if p.exists():
+ o = resolve(opts.output_path)
+ if o.exists():
+ shutil.copyfile(p, o / p.name)
+ if opts.exp_file:
+ p = resolve(opts.exp_file)
+ if p.exists():
+ o = resolve(opts.output_path)
+ if o.exists():
+ shutil.copyfile(p, o / p.name)
+
+
+def merge(
+ source: Union[dict, Dict], destination: Union[dict, Dict]
+) -> Union[dict, Dict]:
+ """
+    Recursively merge source into destination (run me with
+    nosetests --with-doctest file.py)
+
+    >>> a = {'first': {'all_rows': {'pass': 'dog', 'number': '1'}}}
+    >>> b = {'first': {'all_rows': {'fail': 'cat', 'number': '5'}}}
+    >>> merge(b, a) == {
+    ...     'first': {
+    ...         'all_rows': {'pass': 'dog', 'fail': 'cat', 'number': '5'}
+    ...     }
+    ... }
+    True
+ """
+ for key, value in source.items():
+ try:
+ if isinstance(value, dict):
+ # get node or create one
+ node = destination.setdefault(key, {})
+ merge(value, node)
+ else:
+ if isinstance(destination, dict):
+ destination[key] = value
+ else:
+ destination = {key: value}
+ except TypeError as e:
+ print(traceback.format_exc())
+ print(">>>", source)
+ print(">>>", destination)
+ print(">>>", key)
+ print(">>>", value)
+ raise Exception(e)
+
+ return destination
+
+
+def load_opts(
+ path: Optional[Union[str, Path]] = None,
+ default: Optional[Union[str, Path, dict, Dict]] = None,
+ commandline_opts: Optional[Union[Dict, dict]] = None,
+) -> Dict:
+    """Loads a configuration Dict from 2 files:
+ 1. default files with shared values across runs and users
+ 2. an overriding file with run- and user-specific values
+
+ Args:
+ path (pathlib.Path): where to find the overriding configuration
+        default (pathlib.Path, optional): Where to find the default opts.
+            Defaults to None, in which case it is assumed to be a default config
+            which needs processing, such as setting default values for the lambdas
+            and gen fields
+
+ Returns:
+        addict.Dict: options dictionary, with overwritten default values
+ """
+
+ if path is None and default is None:
+ path = (
+ resolve(Path(__file__)).parent.parent
+ / "shared"
+ / "trainer"
+ / "defaults.yaml"
+ )
+
+ if path:
+ path = resolve(path)
+
+ if default is None:
+ default_opts = {}
+ else:
+ if isinstance(default, (str, Path)):
+ with open(default, "r") as f:
+ default_opts = yaml.safe_load(f)
+ else:
+ default_opts = dict(default)
+
+ if path is None:
+ overriding_opts = {}
+ else:
+ with open(path, "r") as f:
+ overriding_opts = yaml.safe_load(f) or {}
+
+ opts = Dict(merge(overriding_opts, default_opts))
+
+ if commandline_opts is not None and isinstance(commandline_opts, dict):
+ opts = Dict(merge(commandline_opts, opts))
+
+ if opts.train.kitti.pretrained:
+ assert "kitti" in opts.data.files.train
+ assert "kitti" in opts.data.files.val
+ assert opts.train.kitti.epochs > 0
+
+ opts.domains = []
+ if "m" in opts.tasks or "s" in opts.tasks or "d" in opts.tasks:
+ opts.domains.extend(["r", "s"])
+ if "p" in opts.tasks:
+ opts.domains.append("rf")
+ if opts.train.kitti.pretrain:
+ opts.domains.append("kitti")
+
+ opts.domains = list(set(opts.domains))
+
+ if "s" in opts.tasks:
+ if opts.gen.encoder.architecture != opts.gen.s.architecture:
+ print(
+ "WARNING: segmentation encoder and decoder architectures do not match"
+ )
+ print(
+ "Encoder: {} <> Decoder: {}".format(
+ opts.gen.encoder.architecture, opts.gen.s.architecture
+ )
+ )
+ if opts.gen.m.use_spade:
+ if "d" not in opts.tasks or "s" not in opts.tasks:
+ raise ValueError(
+ "opts.gen.m.use_spade is True so tasks MUST include"
+ + "both d and s, but received {}".format(opts.tasks)
+ )
+ if opts.gen.d.classify.enable:
+ raise ValueError(
+ "opts.gen.m.use_spade is True but using D as a classifier"
+ + " which is a non-implemented combination"
+ )
+
+ if opts.gen.s.depth_feat_fusion is True or opts.gen.s.depth_dada_fusion is True:
+ opts.gen.s.use_dada = True
+
+ events_path = (
+ resolve(Path(__file__)).parent.parent / "shared" / "trainer" / "events.yaml"
+ )
+ if events_path.exists():
+ with events_path.open("r") as f:
+ events_dict = yaml.safe_load(f)
+ events_dict = Dict(events_dict)
+ opts.events = events_dict
+
+ return set_data_paths(opts)
+
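+# Minimal usage sketch for load_opts (editor's note; "config/exp.yaml" is a
+# hypothetical override file, the defaults file is the one referenced above):
+#
+#   opts = load_opts(
+#       path="config/exp.yaml",
+#       default="shared/trainer/defaults.yaml",
+#   )
+#   # values in config/exp.yaml override those in defaults.yaml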
+
+def set_data_paths(opts: Dict) -> Dict:
+ """Update the data files paths in data.files.train and data.files.val
+ from data.files.base
+
+ Args:
+ opts (addict.Dict): options
+
+ Returns:
+ addict.Dict: updated options
+ """
+
+ for mode in ["train", "val"]:
+ for domain in opts.data.files[mode]:
+ if opts.data.files.base and not opts.data.files[mode][domain].startswith(
+ "/"
+ ):
+ opts.data.files[mode][domain] = str(
+ Path(opts.data.files.base) / opts.data.files[mode][domain]
+ )
+ assert Path(
+ opts.data.files[mode][domain]
+ ).exists(), "Cannot find {}".format(str(opts.data.files[mode][domain]))
+
+ return opts
+
+
+def load_test_opts(test_file_path: str = "config/trainer/local_tests.yaml") -> Dict:
+ """Returns the special opts set up for local tests
+ Args:
+        test_file_path (str, optional): path to the test config file, relative to
+            the repository root. Defaults to "config/trainer/local_tests.yaml".
+
+ Returns:
+ addict.Dict: Opts loaded from defaults.yaml and updated from test_file_path
+ """
+ return load_opts(
+ Path(__file__).parent.parent / f"{test_file_path}",
+ default=Path(__file__).parent.parent / "shared/trainer/defaults.yaml",
+ )
+
+
+def get_git_revision_hash() -> str:
+ """Get current git hash the code is run from
+
+ Returns:
+ str: git hash
+ """
+ try:
+ return subprocess.check_output(["git", "rev-parse", "HEAD"]).decode().strip()
+ except Exception as e:
+ return str(e)
+
+
+def get_git_branch() -> str:
+ """Get current git branch name
+
+ Returns:
+ str: git branch name
+ """
+ try:
+ return (
+ subprocess.check_output(["git", "rev-parse", "--abbrev-ref", "HEAD"])
+ .decode()
+ .strip()
+ )
+ except Exception as e:
+ return str(e)
+
+
+def kill_job(id: Union[int, str]) -> None:
+ subprocess.check_output(["scancel", str(id)])
+
+
+def write_hash(path: Union[str, Path]) -> None:
+ hash_code = get_git_revision_hash()
+ with open(path, "w") as f:
+ f.write(hash_code)
+
+
+def shortuid():
+ return str(uuid4()).split("-")[0]
+
+
+def datenowshort():
+ """
+ >>> a = str(datetime.datetime.now())
+ >>> print(a)
+ '2021-02-25 11:34:50.188072'
+ >>> print(a[5:].split(".")[0].replace(" ", "_"))
+    '02-25_11:34:50'
+
+ Returns:
+ str: month-day_h:m:s
+ """
+ return str(datetime.datetime.now())[5:].split(".")[0].replace(" ", "_")
+
+
+def get_increased_path(path: Union[str, Path], use_date: bool = False) -> Path:
+    """Returns a unique, non-existing path based on `path`: if `path` does not
+    exist it is returned as is; otherwise a short uuid (or the current date if
+    `use_date` is True) is appended until the resulting path does not exist.
+
+    get_increased_path("test").mkdir() creates `test/`
+    then
+    get_increased_path("test").mkdir() creates `test--<uuid>/`
+    etc.
+
+    Args:
+        path (str or pathlib.Path): the file/directory which may already exist and
+            would need to be increased
+        use_date (bool, optional): append a date-based suffix instead of a uuid.
+            Defaults to False.
+
+    Returns:
+        pathlib.Path: increased path
+    """
+ fp = resolve(path)
+ if not fp.exists():
+ return fp
+
+ if fp.is_file():
+ if not use_date:
+ while fp.exists():
+ fp = fp.parent / f"{fp.stem}--{shortuid()}{fp.suffix}"
+ return fp
+ else:
+ while fp.exists():
+ time.sleep(0.5)
+ fp = fp.parent / f"{fp.stem}--{datenowshort()}{fp.suffix}"
+ return fp
+
+ if not use_date:
+ while fp.exists():
+ fp = fp.parent / f"{fp.name}--{shortuid()}"
+ return fp
+ else:
+ while fp.exists():
+ time.sleep(0.5)
+ fp = fp.parent / f"{fp.name}--{datenowshort()}"
+ return fp
+
+ # vals = []
+ # for n in fp.parent.glob("{}*".format(fp.stem)):
+ # if re.match(r".+\(\d+\)", str(n.name)) is not None:
+ # name = str(n.name)
+ # start = name.index("(")
+ # end = name.index(")")
+ # vals.append(int(name[start + 1 : end]))
+ # if vals:
+ # ext = " ({})".format(max(vals) + 1)
+ # elif fp.exists():
+ # ext = " (1)"
+ # else:
+ # ext = ""
+ # return fp.parent / (fp.stem + ext + fp.suffix)
+
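+# Concrete example of get_increased_path (the uuid below is a made-up placeholder):
+#   get_increased_path("out/run")                 -> out/run           (does not exist yet)
+#   get_increased_path("out/run")                 -> out/run--1a2b3c4d (out/run exists)
+#   get_increased_path("out/run", use_date=True)  -> out/run--02-25_11:34:50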
+
+def env_to_path(path: str) -> str:
+    """Transforms an environment variable mentioned in a path
+    into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds
+
+ Args:
+ path (str): path potentially containing the env variable
+
+ """
+ path_elements = path.split("/")
+ new_path = []
+ for el in path_elements:
+ if "$" in el:
+ new_path.append(os.environ[el.replace("$", "")])
+ else:
+ new_path.append(el)
+ return "/".join(new_path)
+
+
+def flatten_opts(opts: Dict) -> dict:
+    """Flattens a multi-level addict.Dict or native dictionary into a single
+ level native dict with string keys representing the keys sequence to reach
+ a value in the original argument.
+
+ d = addict.Dict()
+ d.a.b.c = 2
+ d.a.b.d = 3
+ d.a.e = 4
+ d.f = 5
+ flatten_opts(d)
+ >>> {
+ "a.b.c": 2,
+ "a.b.d": 3,
+ "a.e": 4,
+ "f": 5,
+ }
+
+ Args:
+        opts (addict.Dict or dict): addict dictionary to flatten
+
+    Returns:
+        dict: flattened dictionary
+ """
+ values_list = []
+
+ def p(d, prefix="", vals=[]):
+ for k, v in d.items():
+ if isinstance(v, (Dict, dict)):
+ p(v, prefix + k + ".", vals)
+ elif isinstance(v, list):
+ if v and isinstance(v[0], (Dict, dict)):
+ for i, m in enumerate(v):
+ p(m, prefix + k + "." + str(i) + ".", vals)
+ else:
+ vals.append((prefix + k, str(v)))
+ else:
+ if isinstance(v, Path):
+ v = str(v)
+ vals.append((prefix + k, v))
+
+ p(opts, vals=values_list)
+ return dict(values_list)
+
+
+def get_comet_rest_api_key(
+ path_to_config_file: Optional[Union[str, Path]] = None
+) -> str:
+ """Gets a comet.ml rest_api_key in the following order:
+ * config file specified as argument
+ * environment variable
+        * .comet.config file in the current working directory
+ * .comet.config file in your home
+
+ config files must have a line like `rest_api_key=`
+
+ Args:
+ path_to_config_file (str or pathlib.Path, optional): config_file to use.
+ Defaults to None.
+
+ Raises:
+ ValueError: can't find a file
+ ValueError: can't find the key in a file
+
+ Returns:
+ str: your comet rest_api_key
+ """
+ if "COMET_REST_API_KEY" in os.environ and path_to_config_file is None:
+ return os.environ["COMET_REST_API_KEY"]
+ if path_to_config_file is not None:
+ p = resolve(path_to_config_file)
+ else:
+ p = Path() / ".comet.config"
+ if not p.exists():
+ p = Path.home() / ".comet.config"
+ if not p.exists():
+ raise ValueError("Unable to find your COMET_REST_API_KEY")
+ with p.open("r") as f:
+ for keys in f:
+ if "rest_api_key" in keys:
+ return keys.strip().split("=")[-1].strip()
+ raise ValueError("Unable to find your COMET_REST_API_KEY in {}".format(str(p)))
+
+
+def get_files(dirName: str) -> list:
+ # create a list of file and sub directories
+ files = sorted(os.listdir(dirName))
+ all_files = list()
+ for entry in files:
+ fullPath = os.path.join(dirName, entry)
+ if os.path.isdir(fullPath):
+ all_files = all_files + get_files(fullPath)
+ else:
+ all_files.append(fullPath)
+
+ return all_files
+
+
+def make_json_file(
+ tasks: List[str],
+    addresses: List[str],  # for Windows users, use "\\" instead of "/"
+ json_names: List[str] = ["train_jsonfile.json", "val_jsonfile.json"],
+ splitter: str = "/",
+ pourcentage_val: float = 0.15,
+) -> None:
+ """
+ How to use it?
+ e.g.
+ make_json_file(['x','m','d'], [
+ '/network/tmp1/ccai/data/munit_dataset/trainA_size_1200/',
+ '/network/tmp1/ccai/data/munit_dataset/seg_trainA_size_1200/',
+ '/network/tmp1/ccai/data/munit_dataset/trainA_megadepth_resized/'
+ ], ["train_r.json", "val_r.json"])
+
+ Args:
+ tasks (list): the list of image type like 'x', 'm', 'd', etc.
+ addresses (list): the list of the corresponding address of the
+ image type mentioned in tasks
+ json_names (list): names for the json files, train being first
+ (e.g. : ["train_r.json", "val_r.json"])
+ splitter (str, optional): The path separator for the current OS.
+ Defaults to '/'.
+        pourcentage_val: percentage of files to go in the validation set
+ """
+    assert len(tasks) == len(addresses), "tasks and addresses must have the same length!"
+
+ files = [get_files(addresses[j]) for j in range(len(tasks))]
+ n_files_val = int(pourcentage_val * len(files[0]))
+ n_files_train = len(files[0]) - n_files_val
+ filenames = [files[0][:n_files_train], files[0][-n_files_val:]]
+
+ file_address_map = {
+ tasks[j]: {
+ ".".join(file.split(splitter)[-1].split(".")[:-1]): file
+ for file in files[j]
+ }
+ for j in range(len(tasks))
+ }
+    # The keys of file_address_map are the tasks, like 'x', 'm', 'd'...
+    # The values of file_address_map are dictionaries whose keys are the
+    # filenames without extension and whose values are the paths to the files
+ # e.g. file_address_map =
+ # {'x': {'A': 'path/to/trainA_size_1200/A.png', ...},
+ # 'm': {'A': 'path/to/seg_trainA_size_1200/A.jpg',...}
+ # 'd': {'A': 'path/to/trainA_megadepth_resized/A.bmp',...}
+ # ...}
+
+ for i, json_name in enumerate(json_names):
+ dicts = []
+ for j in range(len(filenames[i])):
+ file = filenames[i][j]
+ filename = file.split(splitter)[-1] # the filename with 'x' extension
+ filename_ = ".".join(
+ filename.split(".")[:-1]
+ ) # the filename without extension
+ tmp_dict = {}
+ for k in range(len(tasks)):
+ tmp_dict[tasks[k]] = file_address_map[tasks[k]][filename_]
+ dicts.append(tmp_dict)
+ with open(json_name, "w", encoding="utf-8") as outfile:
+ json.dump(dicts, outfile, ensure_ascii=False)
+
+
+def append_task_to_json(
+ path_to_json: Union[str, Path],
+ path_to_new_json: Union[str, Path],
+ path_to_new_images_dir: Union[str, Path],
+ new_task_name: str,
+):
+ """Add all files for a task to an existing json file by creating a new json file
+ in the specified path.
+ Assumes that the files for the new task have exactly the same names as the ones
+ for the other tasks
+
+ Args:
+ path_to_json: complete path to the json file to modify
+ path_to_new_json: complete path to the new json file to be created
+ path_to_new_images_dir: complete path of the directory where to find the
+ images for the new task
+ new_task_name: name of the new task
+
+    e.g.:
+        append_task_to_json(
+            "/network/tmp1/ccai/data/climategan/seg/train_r.json",
+            "/network/tmp1/ccai/data/climategan/seg/train_r_new.json",
+            "/network/tmp1/ccai/data/munit_dataset/trainA_seg_HRNet/unity_labels",
+            "s",
+        )
+ """
+ ims_list = None
+ if path_to_json:
+ path_to_json = Path(path_to_json).resolve()
+ with open(path_to_json, "r") as f:
+ ims_list = json.load(f)
+
+ files = get_files(path_to_new_images_dir)
+
+ if ims_list is None:
+ raise ValueError(f"Could not find the list in {path_to_json}")
+
+ new_ims_list = [None] * len(ims_list)
+ for i, im_dict in enumerate(ims_list):
+ new_ims_list[i] = {}
+ for task, path in im_dict.items():
+ new_ims_list[i][task] = path
+
+ for i, im_dict in enumerate(ims_list):
+ for task, path in im_dict.items():
+ file_name = os.path.splitext(path)[0] # removes extension
+ file_name = file_name.rsplit("/", 1)[-1] # only the file_name
+ file_found = False
+ for file_path in files:
+ if file_name in file_path:
+ file_found = True
+ new_ims_list[i][new_task_name] = file_path
+ break
+ if file_found:
+ break
+ else:
+ print("Error! File ", file_name, "not found in directory!")
+ return
+
+ with open(path_to_new_json, "w", encoding="utf-8") as f:
+ json.dump(new_ims_list, f, ensure_ascii=False)
+
+
+def sum_dict(dict1: Union[dict, Dict], dict2: Union[Dict, dict]) -> Union[dict, Dict]:
+ """Add dict2 into dict1"""
+ for k, v in dict2.items():
+ if not isinstance(v, dict):
+ dict1[k] += v
+ else:
+ sum_dict(dict1[k], dict2[k])
+ return dict1
+
+
+def div_dict(dict1: Union[dict, Dict], div_by: float) -> dict:
+ """Divide elements of dict1 by div_by"""
+ for k, v in dict1.items():
+ if not isinstance(v, dict):
+ dict1[k] /= div_by
+ else:
+ div_dict(dict1[k], div_by)
+ return dict1
+
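+# Small example of the two helpers above (nested dicts are handled recursively and
+# dict1 is modified in place):
+#   sum_dict({"a": 1, "b": {"c": 2.0}}, {"a": 3, "b": {"c": 4.0}})
+#   >>> {"a": 4, "b": {"c": 6.0}}
+#   div_dict({"a": 4, "b": {"c": 6.0}}, 2)
+#   >>> {"a": 2.0, "b": {"c": 3.0}}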
+
+def comet_id_from_url(url: str) -> Optional[str]:
+ """
+ Get comet exp id from its url:
+ https://www.comet.ml/vict0rsch/climategan/2a1a4a96afe848218c58ac4e47c5375f
+ -> 2a1a4a96afe848218c58ac4e47c5375f
+
+ Args:
+ url (str): comet exp url
+
+ Returns:
+ str: comet exp id
+ """
+ try:
+ ids = url.split("/")
+ ids = [i for i in ids if i]
+ return ids[-1]
+ except Exception:
+ return None
+
+
+@contextlib.contextmanager
+def temp_np_seed(seed: Optional[int]) -> None:
+ """
+ Set temporary numpy seed:
+ with temp_np_seed(123):
+ np.random.permutation(3)
+
+ Args:
+ seed (int): temporary numpy seed
+ """
+ state = np.random.get_state()
+ np.random.seed(seed)
+ try:
+ yield
+ finally:
+ np.random.set_state(state)
+
+
+def get_display_indices(opts: Dict, domain: str, length: int) -> list:
+ """
+    Compute the indices of images to use for comet logging:
+    if opts.comet.display_size is an int:
+        return a fixed-seed permutation(length)[:size]
+        (for the "rf" domain, size also accounts for opts.train.fid.n_images)
+    if opts.comet.display_size is a list:
+        return that list
+
+ otherwise return []
+
+
+ Args:
+ opts (addict.Dict): options
+ domain (str): domain for those indices
+ length (int): length of dataset for the permutation
+
+ Returns:
+ list(int): The indices to display
+ """
+ if domain == "rf":
+ dsize = max([opts.comet.display_size, opts.train.fid.get("n_images", 0)])
+ else:
+ dsize = opts.comet.display_size
+ if dsize > length:
+ print(
+ f"Warning: dataset is smaller ({length} images) "
+ + f"than required display indices ({dsize})."
+ + f" Selecting {length} images."
+ )
+
+ display_indices = []
+ assert isinstance(dsize, (int, list)), "Unknown display size {}".format(dsize)
+ if isinstance(dsize, int):
+ assert dsize >= 0, "Display size cannot be < 0"
+ with temp_np_seed(123):
+ display_indices = list(np.random.permutation(length)[:dsize])
+ elif isinstance(dsize, list):
+ display_indices = dsize
+
+ if not display_indices:
+ print("Warning: no display indices (utils.get_display_indices)")
+
+ return display_indices
+
+
+def get_latest_path(path: Union[str, Path]) -> Path:
+ """
+ Get the file/dir with largest increment i as `file (i).ext`
+
+ Args:
+ path (str or pathlib.Path): base pattern
+
+ Returns:
+ Path: path found
+ """
+ p = Path(path).resolve()
+ s = p.stem
+ e = p.suffix
+ files = list(p.parent.glob(f"{s}*(*){e}"))
+    indices = [f.name for f in files]
+ indices = list(map(lambda x: re.findall(r"\((.*?)\)", x)[-1], indices))
+ indices = list(map(int, indices))
+ if not indices:
+ f = p
+ else:
+ f = files[np.argmax(indices)]
+ return f
+
+
+def get_existing_jobID(output_path: Path) -> str:
+ """
+ If the opts in output_path have a jobID, return it. Else, return None
+
+ Args:
+ output_path (pathlib.Path | str): where to look
+
+ Returns:
+ str | None: jobid
+ """
+ op = Path(output_path)
+ if not op.exists():
+ return
+
+ opts_path = get_latest_path(op / "opts.yaml")
+
+ if not opts_path.exists():
+ return
+
+ with opts_path.open("r") as f:
+ opts = yaml.safe_load(f)
+
+ jobID = opts.get("jobID", None)
+
+ return jobID
+
+
+def find_existing_training(opts: Dict) -> Optional[Path]:
+ """
+ Looks in all directories like output_path.parent.glob(output_path.name*)
+ and compares the logged slurm job id with the current opts.jobID
+
+ If a match is found, the training should automatically continue in the
+ matching output directory
+
+ If no match is found, this is a new job and it should have a new output path
+
+ Args:
+ opts (Dict): trainer's options
+
+ Returns:
+        Optional[Path]: a path if a matching jobID is found, None otherwise
+ """
+ if opts.jobID is None:
+ print("WARNING: current JOBID is None")
+ return
+
+ print("---------- Current job id:", opts.jobID)
+
+ path = Path(opts.output_path).resolve()
+ parent = path.parent
+ name = path.name
+
+ try:
+ similar_dirs = [p.resolve() for p in parent.glob(f"{name}*") if p.is_dir()]
+
+ for sd in similar_dirs:
+ candidate_jobID = get_existing_jobID(sd)
+ if candidate_jobID is not None and str(opts.jobID) == str(candidate_jobID):
+ print(f"Found matching job id in {sd}\n")
+ return sd
+ print("Did not find a matching job id in \n {}\n".format(str(similar_dirs)))
+ except Exception as e:
+ print("ERROR: Could not resume (find_existing_training)", e)
+
+
+def pprint(*args: List[Any]):
+ """
+ Prints *args within a box of "=" characters
+ """
+ txt = " ".join(map(str, args))
+ col = "====="
+ space = " "
+ head_size = 2
+ header = "\n".join(["=" * (len(txt) + 2 * (len(col) + len(space)))] * head_size)
+ empty = "{}{}{}{}{}".format(col, space, " " * (len(txt)), space, col)
+ print()
+ print(header)
+ print(empty)
+ print("{}{}{}{}{}".format(col, space, txt, space, col))
+ print(empty)
+ print(header)
+ print()
+
+
+def get_existing_comet_id(path: str) -> Optional[str]:
+ """
+ Returns the id of the existing comet experiment stored in path
+
+ Args:
+        path (str): Output path where to look for the comet exp
+
+ Returns:
+ Optional[str]: comet exp's ID if any was found
+ """
+ comet_previous_path = get_latest_path(Path(path) / "comet_url.txt")
+ if comet_previous_path.exists():
+ with comet_previous_path.open("r") as f:
+ url = f.read().strip()
+ return comet_id_from_url(url)
+
+
+def get_latest_opts(path):
+ """
+ get latest opts dumped in path if they look like *opts*.yaml
+ and were increased as
+ opts.yaml < opts (1).yaml < opts (2).yaml etc.
+
+ Args:
+ path (str or pathlib.Path): where to look for opts
+
+    Raises:
+        AssertionError: If no match for *opts*.yaml is found
+
+ Returns:
+ addict.Dict: loaded opts
+ """
+ path = Path(path)
+ opts = get_latest_path(path / "opts.yaml")
+ assert opts.exists()
+ with opts.open("r") as f:
+ opts = Dict(yaml.safe_load(f))
+
+ events_path = Path(__file__).parent.parent / "shared" / "trainer" / "events.yaml"
+ if events_path.exists():
+ with events_path.open("r") as f:
+ events_dict = yaml.safe_load(f)
+ events_dict = Dict(events_dict)
+ opts.events = events_dict
+
+ return opts
+
+
+def text_to_array(text, width=640, height=40):
+ """
+ Creates a numpy array of shape height x width x 3 with
+ text written on it using PIL
+
+ Args:
+ text (str): text to write
+ width (int, optional): Width of the resulting array. Defaults to 640.
+ height (int, optional): Height of the resulting array. Defaults to 40.
+
+ Returns:
+ np.ndarray: Centered text
+ """
+ from PIL import Image, ImageDraw, ImageFont
+
+ img = Image.new("RGB", (width, height), (255, 255, 255))
+ try:
+ font = ImageFont.truetype("UnBatang.ttf", 25)
+ except OSError:
+ font = ImageFont.load_default()
+
+ d = ImageDraw.Draw(img)
+ text_width, text_height = d.textsize(text)
+    h = height // 2 - 3 * text_height // 2
+ w = width // 2 - text_width
+ d.text((w, h), text, font=font, fill=(30, 30, 30))
+ return np.array(img)
+
+
+def all_texts_to_array(texts, width=640, height=40):
+ """
+ Creates an array of texts, each of height and width specified
+ by the args, concatenated along their width dimension
+
+ Args:
+ texts (list(str)): List of texts to concatenate
+ width (int, optional): Individual text's width. Defaults to 640.
+ height (int, optional): Individual text's height. Defaults to 40.
+
+ Returns:
+ list: len(texts) text arrays with dims height x width x 3
+ """
+ return [text_to_array(text, width, height) for text in texts]
+
+
+class Timer:
+ def __init__(self, name="", store=None, precision=3, ignore=False, cuda=True):
+ self.name = name
+ self.store = store
+ self.precision = precision
+ self.ignore = ignore
+ self.cuda = cuda
+
+ if cuda:
+ self._start_event = torch.cuda.Event(enable_timing=True)
+ self._end_event = torch.cuda.Event(enable_timing=True)
+
+ def format(self, n):
+ return f"{n:.{self.precision}f}"
+
+ def __enter__(self):
+ """Start a new timer as a context manager"""
+ if self.cuda:
+ self._start_event.record()
+ else:
+ self._start_time = time.perf_counter()
+ return self
+
+ def __exit__(self, *exc_info):
+ """Stop the context manager timer"""
+ if self.ignore:
+ return
+
+ if self.cuda:
+ self._end_event.record()
+ torch.cuda.synchronize()
+ new_time = self._start_event.elapsed_time(self._end_event) / 1000
+ else:
+ t = time.perf_counter()
+ new_time = t - self._start_time
+
+ if self.store is not None:
+ assert isinstance(self.store, list)
+ self.store.append(new_time)
+ if self.name:
+ print(f"[{self.name}] Elapsed time: {self.format(new_time)}")
+
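+# Usage sketch for Timer (editor's note; `times`, `model` and `x` are placeholders):
+#
+#   times = []
+#   with Timer("forward", store=times, cuda=False):
+#       out = model(x)
+#   # prints "[forward] Elapsed time: ..." and appends the duration to `times`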
+
+def get_loader_output_shape_from_opts(opts):
+ transforms = opts.data.transforms
+
+ t = None
+ for t in transforms[::-1]:
+ if t.name == "resize":
+ break
+ assert t is not None
+
+ if isinstance(t.new_size, Dict):
+ return {
+ task: (
+ t.new_size.get(task, t.new_size.default),
+ t.new_size.get(task, t.new_size.default),
+ )
+ for task in opts.tasks + ["x"]
+ }
+ assert isinstance(t.new_size, int)
+ new_size = (t.new_size, t.new_size)
+ return {task: new_size for task in opts.tasks + ["x"]}
+
+
+def find_target_size(opts, task):
+ target_size = None
+ if isinstance(opts.data.transforms[-1].new_size, int):
+ target_size = opts.data.transforms[-1].new_size
+ else:
+ if task in opts.data.transforms[-1].new_size:
+ target_size = opts.data.transforms[-1].new_size[task]
+ else:
+ assert "default" in opts.data.transforms[-1].new_size
+ target_size = opts.data.transforms[-1].new_size["default"]
+
+ return target_size
+
+
+def to_128(im, w_target=-1):
+ h, w = im.shape[:2]
+ aspect_ratio = h / w
+ if w_target < 0:
+ w_target = w
+
+ nw = int(w_target / 128) * 128
+ nh = int(nw * aspect_ratio / 128) * 128
+
+ return nh, nw
+
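+# Worked example for to_128 (values follow from the arithmetic above): an image of
+# shape (720, 1280) with the default w_target=-1 gives
+#   nw = int(1280 / 128) * 128 = 1280
+#   nh = int(1280 * (720 / 1280) / 128) * 128 = int(5.625) * 128 = 640
+# i.e. (nh, nw) == (640, 1280), both multiples of 128.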
+
+def is_image_file(filename):
+ """Check that a file's name points to a known image format"""
+ if isinstance(filename, Path):
+ return filename.suffix in IMG_EXTENSIONS
+
+ return Path(filename).suffix in IMG_EXTENSIONS
+
+
+def find_images(path, recursive=False):
+ """
+ Get a list of all images contained in a directory:
+
+ - path.glob("*") if not recursive
+ - path.glob("**/*") if recursive
+ """
+ p = Path(path)
+ assert p.exists()
+ assert p.is_dir()
+ pattern = "*"
+ if recursive:
+ pattern += "*/*"
+
+ return [i for i in p.glob(pattern) if i.is_file() and is_image_file(i)]
+
+
+def cols():
+ try:
+ col = os.get_terminal_size().columns
+ except Exception:
+ col = 50
+ return col
+
+
+def upload_images_to_exp(
+ path, exp=None, project_name="climategan-eval", sleep=-1, verbose=0
+):
+ ims = find_images(path)
+ end = None
+ c = cols()
+ if verbose == 1:
+ end = "\r"
+ if verbose > 1:
+ end = "\n"
+ if exp is None:
+ exp = Experiment(project_name=project_name)
+ for im in ims:
+ exp.log_image(str(im))
+ if verbose > 0:
+ if verbose == 1:
+ print(" " * (c - 1), end="\r", flush=True)
+ print(str(im), end=end, flush=True)
+ if sleep > 0:
+ time.sleep(sleep)
+ return exp
diff --git a/config/model/masker/.ipynb_checkpoints/opts-checkpoint.yaml b/config/model/masker/.ipynb_checkpoints/opts-checkpoint.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c36d19b93b696a4396d013876b98e04e44b9277a
--- /dev/null
+++ b/config/model/masker/.ipynb_checkpoints/opts-checkpoint.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8fd82a0d6c1de82a4ec0c1f70d0a7d3533b603a4d1ecf1c4a93d0e48aa94c31
+size 6730
diff --git a/config/model/masker/opts.yaml b/config/model/masker/opts.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c36d19b93b696a4396d013876b98e04e44b9277a
--- /dev/null
+++ b/config/model/masker/opts.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e8fd82a0d6c1de82a4ec0c1f70d0a7d3533b603a4d1ecf1c4a93d0e48aa94c31
+size 6730
diff --git a/config/model/painter/opts.yaml b/config/model/painter/opts.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..05a22b31e9ac6171ce6035991d2293913c5da6a8
--- /dev/null
+++ b/config/model/painter/opts.yaml
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:743b4cb46c6c62c424e348fec0093171cece006547deaec66f5324937bab4c13
+size 5329
diff --git a/eval_masker.py b/eval_masker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a284d820f5f2ffec971b0784189503ebbc1a0c08
--- /dev/null
+++ b/eval_masker.py
@@ -0,0 +1,796 @@
+"""
+Compute metrics of the performance of the masker using a set of ground-truth labels
+
+run eval_masker.py --model "/miniscratch/_groups/ccai/checkpoints/model/"
+
+"""
+print("Imports...", end="")
+import os
+import os.path
+from argparse import ArgumentParser
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from comet_ml import Experiment
+import torch
+import yaml
+from skimage.color import rgba2rgb
+from skimage.io import imread, imsave
+from skimage.transform import resize
+from skimage.util import img_as_ubyte
+from torchvision.transforms import ToTensor
+
+from climategan.data import encode_mask_label
+from climategan.eval_metrics import (
+ masker_classification_metrics,
+ get_confusion_matrix,
+ edges_coherence_std_min,
+ boxplot_metric,
+ clustermap_metric,
+)
+from climategan.transforms import PrepareTest
+from climategan.trainer import Trainer
+from climategan.utils import find_images
+
+dict_metrics = {
+ "names": {
+ "tpr": "TPR, Recall, Sensitivity",
+ "tnr": "TNR, Specificity, Selectivity",
+ "fpr": "FPR",
+ "fpt": "False positives relative to image size",
+ "fnr": "FNR, Miss rate",
+ "fnt": "False negatives relative to image size",
+ "mpr": "May positive rate (MPR)",
+ "mnr": "May negative rate (MNR)",
+ "accuracy": "Accuracy (ignoring may)",
+ "error": "Error (ignoring may)",
+ "f05": "F0.05 score",
+ "precision": "Precision",
+ "edge_coherence": "Edge coherence",
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
+ },
+ "threshold": {
+ "tpr": 0.95,
+ "tnr": 0.95,
+ "fpr": 0.05,
+ "fpt": 0.01,
+ "fnr": 0.05,
+ "fnt": 0.01,
+ "accuracy": 0.95,
+ "error": 0.05,
+ "f05": 0.95,
+ "precision": 0.95,
+ "edge_coherence": 0.02,
+ "accuracy_must_may": 0.5,
+ },
+ "key_metrics": ["f05", "error", "edge_coherence", "mnr"],
+}
+
+print("Ok.")
+
+
+def parsed_args():
+ """Parse and returns command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--model",
+ type=str,
+ help="Path to a pre-trained model",
+ )
+ parser.add_argument(
+ "--images_dir",
+ default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/imgs",
+ type=str,
+ help="Directory containing the original test images",
+ )
+ parser.add_argument(
+ "--labels_dir",
+ default="/miniscratch/_groups/ccai/data/omnigan/masker-test-set/labels",
+ type=str,
+ help="Directory containing the labeled images",
+ )
+ parser.add_argument(
+ "--image_size",
+ default=640,
+ type=int,
+        help="The height and width of the pre-processed images",
+ )
+ parser.add_argument(
+ "--max_files",
+ default=-1,
+ type=int,
+ help="Limit loaded samples",
+ )
+ parser.add_argument(
+ "--bin_value", default=0.5, type=float, help="Mask binarization threshold"
+ )
+ parser.add_argument(
+ "-y",
+ "--yaml",
+ default=None,
+ type=str,
+ help="load a yaml file to parametrize the evaluation",
+ )
+ parser.add_argument(
+ "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
+ )
+ parser.add_argument(
+ "-p",
+ "--plot",
+ action="store_true",
+ default=False,
+ help="Plot masker images & their metrics overlays",
+ )
+ parser.add_argument(
+ "--no_paint",
+ action="store_true",
+ default=False,
+ help="Do not log painted images",
+ )
+ parser.add_argument(
+ "--write_metrics",
+ action="store_true",
+ default=False,
+ help="If True, write CSV file and maps images in model's path directory",
+ )
+ parser.add_argument(
+ "--load_metrics",
+ action="store_true",
+ default=False,
+ help="If True, load predictions and metrics instead of re-computing",
+ )
+ parser.add_argument(
+ "--prepare_torch",
+ action="store_true",
+ default=False,
+ help="If True, pre-process images as torch tensors",
+ )
+ parser.add_argument(
+ "--output_csv",
+ default=None,
+ type=str,
+ help="Filename of the output CSV with the metrics of all models",
+ )
+
+ return parser.parse_args()
+
+
+def uint8(array):
+ return array.astype(np.uint8)
+
+
+def crop_and_resize(image_path, label_path):
+ """
+    Resizes an image so that it keeps the aspect ratio and its smallest dimension
+    is 640, then crops this resized image in its center so that the output is 640x640
+ without aspect ratio distortion
+
+ Args:
+ image_path (Path or str): Path to an image
+ label_path (Path or str): Path to the image's associated label
+
+ Returns:
+ tuple((np.ndarray, np.ndarray)): (new image, new label)
+ """
+
+ img = imread(image_path)
+ lab = imread(label_path)
+
+ # if img.shape[-1] == 4:
+ # img = uint8(rgba2rgb(img) * 255)
+
+ # TODO: remove (debug)
+ if img.shape[:2] != lab.shape[:2]:
+ print(
+ "\nWARNING: shape mismatch: im -> ({}) {}, lab -> ({}) {}".format(
+ img.shape[:2], image_path.name, lab.shape[:2], label_path.name
+ )
+ )
+ # breakpoint()
+
+ # resize keeping aspect ratio: smallest dim is 640
+ i_h, i_w = img.shape[:2]
+ if i_h < i_w:
+ i_size = (640, int(640 * i_w / i_h))
+ else:
+ i_size = (int(640 * i_h / i_w), 640)
+
+    l_h, l_w = lab.shape[:2]
+ if l_h < l_w:
+ l_size = (640, int(640 * l_w / l_h))
+ else:
+ l_size = (int(640 * l_h / l_w), 640)
+
+ r_img = resize(img, i_size, preserve_range=True, anti_aliasing=True)
+ r_img = uint8(r_img)
+
+ r_lab = resize(lab, l_size, preserve_range=True, anti_aliasing=False, order=0)
+ r_lab = uint8(r_lab)
+
+ # crop in the center
+ H, W = r_img.shape[:2]
+
+ top = (H - 640) // 2
+ left = (W - 640) // 2
+
+ rc_img = r_img[top : top + 640, left : left + 640, :]
+ rc_lab = (
+ r_lab[top : top + 640, left : left + 640, :]
+ if r_lab.ndim == 3
+ else r_lab[top : top + 640, left : left + 640]
+ )
+
+ return rc_img, rc_lab
+
+
+def plot_images(
+ output_filename,
+ img,
+ label,
+ pred,
+ metrics_dict,
+ maps_dict,
+ edge_coherence=-1,
+ pred_edge=None,
+ label_edge=None,
+ dpi=300,
+ alpha=0.5,
+ vmin=0.0,
+ vmax=1.0,
+ fontsize="xx-small",
+ cmap={
+ "fp": "Reds",
+ "fn": "Reds",
+ "may_neg": "Oranges",
+ "may_pos": "Purples",
+ "pred": "Greens",
+ },
+):
+ f, axes = plt.subplots(1, 5, dpi=dpi)
+
+ # FPR (predicted mask on cannot flood)
+ axes[0].imshow(img)
+ fp_map_plt = axes[0].imshow( # noqa: F841
+ maps_dict["fp"], vmin=vmin, vmax=vmax, cmap=cmap["fp"], alpha=alpha
+ )
+ axes[0].axis("off")
+ axes[0].set_title("FPR: {:.4f}".format(metrics_dict["fpr"]), fontsize=fontsize)
+
+ # FNR (missed mask on must flood)
+ axes[1].imshow(img)
+ fn_map_plt = axes[1].imshow( # noqa: F841
+ maps_dict["fn"], vmin=vmin, vmax=vmax, cmap=cmap["fn"], alpha=alpha
+ )
+ axes[1].axis("off")
+ axes[1].set_title("FNR: {:.4f}".format(metrics_dict["fnr"]), fontsize=fontsize)
+
+ # May flood
+ axes[2].imshow(img)
+ if edge_coherence != -1:
+ title = "MNR: {:.2f} | MPR: {:.2f}\nEdge coh.: {:.4f}".format(
+ metrics_dict["mnr"], metrics_dict["mpr"], edge_coherence
+ )
+ # alpha_here = alpha / 4.
+ # pred_edge_plt = axes[2].imshow(
+ # 1.0 - pred_edge, cmap="gray", alpha=alpha_here
+ # )
+ # label_edge_plt = axes[2].imshow(
+ # 1.0 - label_edge, cmap="gray", alpha=alpha_here
+ # )
+ else:
+        title = "MNR: {:.2f} | MPR: {:.2f}".format(
+            metrics_dict["mnr"], metrics_dict["mpr"]
+        )
+ # alpha_here = alpha / 2.
+ may_neg_map_plt = axes[2].imshow( # noqa: F841
+ maps_dict["may_neg"], vmin=vmin, vmax=vmax, cmap=cmap["may_neg"], alpha=alpha
+ )
+ may_pos_map_plt = axes[2].imshow( # noqa: F841
+ maps_dict["may_pos"], vmin=vmin, vmax=vmax, cmap=cmap["may_pos"], alpha=alpha
+ )
+ axes[2].set_title(title, fontsize=fontsize)
+ axes[2].axis("off")
+
+ # Prediction
+ axes[3].imshow(img)
+ pred_mask = axes[3].imshow( # noqa: F841
+ pred, vmin=vmin, vmax=vmax, cmap=cmap["pred"], alpha=alpha
+ )
+ axes[3].set_title("Predicted mask", fontsize=fontsize)
+ axes[3].axis("off")
+
+ # Labels
+ axes[4].imshow(img)
+ label_mask = axes[4].imshow(label, alpha=alpha) # noqa: F841
+ axes[4].set_title("Labels", fontsize=fontsize)
+ axes[4].axis("off")
+
+ f.savefig(
+ output_filename,
+ dpi=f.dpi,
+ bbox_inches="tight",
+ facecolor="white",
+ transparent=False,
+ )
+ plt.close(f)
+
+
+def load_ground(ground_output_path, ref_image_path):
+ gop = Path(ground_output_path)
+ rip = Path(ref_image_path)
+
+ ground_paths = list((gop / "eval-metrics" / "pred").glob(f"{rip.stem}.jpg")) + list(
+ (gop / "eval-metrics" / "pred").glob(f"{rip.stem}.png")
+ )
+ if len(ground_paths) == 0:
+ raise ValueError(
+ f"Could not find a ground match in {str(gop)} for image {str(rip)}"
+ )
+ elif len(ground_paths) > 1:
+ raise ValueError(
+ f"Found more than 1 ground match in {str(gop)} for image {str(rip)}:"
+ + f" {list(map(str, ground_paths))}"
+ )
+ ground_path = ground_paths[0]
+ _, ground = crop_and_resize(rip, ground_path)
+ if ground.ndim == 3:
+ ground = ground[:, :, 0]
+ ground = (ground > 0).astype(np.float32)
+ return torch.from_numpy(ground).unsqueeze(0).unsqueeze(0).cuda()
+
+
+def get_inferences(
+ image_arrays, model_path, image_paths, paint=False, bin_value=0.5, verbose=0
+):
+ """
+ Obtains the mask predictions of a model for a set of images
+
+ Parameters
+ ----------
+ image_arrays : array-like
+ A list of (1, CH, H, W) images
+
+ image_paths: list(Path)
+ A list of paths for images, in the same order as image_arrays
+
+ model_path : str
+ The path to a pre-trained model
+
+ Returns
+ -------
+ masks : list
+ A list of (H, W) predicted masks
+ """
+ device = torch.device("cuda:0")
+ torch.set_grad_enabled(False)
+ to_tensor = ToTensor()
+
+ is_ground = "ground" in Path(model_path).name
+ is_instagan = "instagan" in Path(model_path).name
+
+ if is_ground or is_instagan:
+        # we just care about the painter here
+ ground_path = model_path
+ model_path = (
+ "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--38858350"
+ )
+
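+    # Convert the arrays to (1, C, H, W) float tensors on the device and rescale
+    # them from [0, 1] to [-1, 1]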
+ xs = [to_tensor(array).unsqueeze(0) for array in image_arrays]
+ xs = [x.to(torch.float32).to(device) for x in xs]
+ xs = [(x - 0.5) * 2 for x in xs]
+ trainer = Trainer.resume_from_path(
+ model_path, inference=True, new_exp=None, device=device
+ )
+ masks = []
+ painted = []
+ for idx, x in enumerate(xs):
+ if verbose > 0:
+ print(idx, "/", len(xs), end="\r")
+
+ if not is_ground and not is_instagan:
+ m = trainer.G.mask(x=x)
+ else:
+ m = load_ground(ground_path, image_paths[idx])
+
+ masks.append(m.squeeze().cpu())
+ if paint:
+ p = trainer.G.paint(m > bin_value, x)
+ painted.append(p.squeeze().cpu())
+ return masks, painted
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ try:
+ tmp_dir = Path(os.environ["SLURM_TMPDIR"])
+ except Exception as e:
+ print(e)
+ tmp_dir = Path(input("Enter tmp output directory: ")).resolve()
+
+ plot_dir = tmp_dir / "plots"
+ plot_dir.mkdir(parents=True, exist_ok=True)
+
+ # Build paths to data
+ imgs_paths = sorted(
+ find_images(args.images_dir, recursive=False), key=lambda x: x.name
+ )
+ labels_paths = sorted(
+ find_images(args.labels_dir, recursive=False),
+ key=lambda x: x.name.replace("_labeled.", "."),
+ )
+ if args.max_files > 0:
+ imgs_paths = imgs_paths[: args.max_files]
+ labels_paths = labels_paths[: args.max_files]
+
+ print(f"Loading {len(imgs_paths)} images and labels...")
+
+ # Pre-process images: resize + crop
+ # TODO: ? make cropping more flexible, not only central
+ if not args.prepare_torch:
+ ims_labs = [crop_and_resize(i, l) for i, l in zip(imgs_paths, labels_paths)]
+ imgs = [d[0] for d in ims_labs]
+ labels = [d[1] for d in ims_labs]
+ else:
+ prepare = PrepareTest()
+ imgs = prepare(imgs_paths, normalize=False, rescale=False)
+ labels = prepare(labels_paths, normalize=False, rescale=False)
+
+ imgs = [i.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for i in imgs]
+ labels = [
+ lab.squeeze(0).permute(1, 2, 0).numpy().astype(np.uint8) for lab in labels
+ ]
+ imgs = [rgba2rgb(img) if img.shape[-1] == 4 else img for img in imgs]
+ print(" Done.")
+
+ # Encode labels
+ print("Encode labels...", end="", flush=True)
+ # HW label
+ labels = [np.squeeze(encode_mask_label(label, "flood")) for label in labels]
+ print("Done.")
+
+ if args.yaml:
+ y_path = Path(args.yaml)
+ assert y_path.exists()
+ assert y_path.suffix in {".yaml", ".yml"}
+ with y_path.open("r") as f:
+ data = yaml.safe_load(f)
+ assert "models" in data
+
+ evaluations = [m for m in data["models"]]
+ else:
+ evaluations = [args.model]
+
+ for e, eval_path in enumerate(evaluations):
+ print("\n>>>>> Evaluation", e, ":", eval_path)
+ print("=" * 50)
+ print("=" * 50)
+
+ model_metrics_path = Path(eval_path) / "eval-metrics"
+ model_metrics_path.mkdir(exist_ok=True)
+ if args.load_metrics:
+ f_csv = model_metrics_path / "eval_masker.csv"
+ pred_out = model_metrics_path / "pred"
+ if f_csv.exists() and pred_out.exists():
+ print("Skipping model because pre-computed metrics exist")
+ continue
+
+ # Initialize New Comet Experiment
+ exp = Experiment(
+ project_name="climategan-masker-metrics", display_summary_level=0
+ )
+
+ # Obtain mask predictions
+ # TODO: remove (debug)
+ print("Obtain mask predictions", end="", flush=True)
+
+ preds, painted = get_inferences(
+ imgs,
+ eval_path,
+ imgs_paths,
+ paint=not args.no_paint,
+ bin_value=args.bin_value,
+ verbose=1,
+ )
+ preds = [pred.numpy() for pred in preds]
+ print(" Done.")
+
+ if args.bin_value > 0:
+ preds = [pred > args.bin_value for pred in preds]
+
+ # Compute metrics
+ df = pd.DataFrame(
+ columns=[
+ "tpr",
+ "tpt",
+ "tnr",
+ "tnt",
+ "fpr",
+ "fpt",
+ "fnr",
+ "fnt",
+ "mnr",
+ "mpr",
+ "accuracy",
+ "error",
+ "precision",
+ "f05",
+ "accuracy_must_may",
+ "edge_coherence",
+ "filename",
+ ]
+ )
+
+ print("Compute metrics and plot images")
+    for idx, (img, label, pred) in enumerate(zip(imgs, labels, preds)):
+ print(idx, "/", len(imgs), end="\r")
+
+ # Basic classification metrics
+ metrics_dict, maps_dict = masker_classification_metrics(
+ pred, label, labels_dict={"cannot": 0, "must": 1, "may": 2}
+ )
+
+ # Edges coherence
+ edge_coherence, pred_edge, label_edge = edges_coherence_std_min(pred, label)
+
+ series_dict = {
+ "tpr": metrics_dict["tpr"],
+ "tpt": metrics_dict["tpt"],
+ "tnr": metrics_dict["tnr"],
+ "tnt": metrics_dict["tnt"],
+ "fpr": metrics_dict["fpr"],
+ "fpt": metrics_dict["fpt"],
+ "fnr": metrics_dict["fnr"],
+ "fnt": metrics_dict["fnt"],
+ "mnr": metrics_dict["mnr"],
+ "mpr": metrics_dict["mpr"],
+ "accuracy": metrics_dict["accuracy"],
+ "error": metrics_dict["error"],
+ "precision": metrics_dict["precision"],
+ "f05": metrics_dict["f05"],
+ "accuracy_must_may": metrics_dict["accuracy_must_may"],
+ "edge_coherence": edge_coherence,
+ "filename": str(imgs_paths[idx].name),
+ }
+ df.loc[idx] = pd.Series(series_dict)
+
+ for k, v in series_dict.items():
+ if k == "filename":
+ continue
+ exp.log_metric(f"img_{k}", v, step=idx)
+
+ # Confusion matrix
+ confmat, _ = get_confusion_matrix(
+ metrics_dict["tpr"],
+ metrics_dict["tnr"],
+ metrics_dict["fpr"],
+ metrics_dict["fnr"],
+ metrics_dict["mnr"],
+ metrics_dict["mpr"],
+ )
+ confmat = np.around(confmat, decimals=3)
+ exp.log_confusion_matrix(
+ file_name=imgs_paths[idx].name + ".json",
+ title=imgs_paths[idx].name,
+ matrix=confmat,
+ labels=["Cannot", "Must", "May"],
+ row_label="Predicted",
+ column_label="Ground truth",
+ )
+
+ if args.plot:
+ # Plot prediction images
+ fig_filename = plot_dir / imgs_paths[idx].name
+ plot_images(
+ fig_filename,
+ img,
+ label,
+ pred,
+ metrics_dict,
+ maps_dict,
+ edge_coherence,
+ pred_edge,
+ label_edge,
+ )
+ exp.log_image(fig_filename)
+ if not args.no_paint:
+ masked = img * (1 - pred[..., None])
+ flooded = img_as_ubyte(
+ (painted[idx].permute(1, 2, 0).cpu().numpy() + 1) / 2
+ )
+ combined = np.concatenate([img, masked, flooded], 1)
+ exp.log_image(combined, imgs_paths[idx].name)
+
+ if args.write_metrics:
+ pred_out = model_metrics_path / "pred"
+ pred_out.mkdir(exist_ok=True)
+ imsave(
+ pred_out / f"{imgs_paths[idx].stem}_pred.png",
+ pred.astype(np.uint8),
+ )
+ for k, v in maps_dict.items():
+ metric_out = model_metrics_path / k
+ metric_out.mkdir(exist_ok=True)
+ imsave(
+ metric_out / f"{imgs_paths[idx].stem}_{k}.png",
+ v.astype(np.uint8),
+ )
+
+ # --------------------------------
+ # ----- END OF IMAGES LOOP -----
+ # --------------------------------
+
+ if args.write_metrics:
+ print(f"Writing metrics in {str(model_metrics_path)}")
+ f_csv = model_metrics_path / "eval_masker.csv"
+ df.to_csv(f_csv, index_label="idx")
+
+ print(" Done.")
+ # Summary statistics
+ means = df.mean(axis=0)
+ confmat_mean, confmat_std = get_confusion_matrix(
+ df.tpr, df.tnr, df.fpr, df.fnr, df.mpr, df.mnr
+ )
+ confmat_mean = np.around(confmat_mean, decimals=3)
+ confmat_std = np.around(confmat_std, decimals=3)
+
+ # Log to comet
+ exp.log_confusion_matrix(
+ file_name="confusion_matrix_mean.json",
+ title="confusion_matrix_mean.json",
+ matrix=confmat_mean,
+ labels=["Cannot", "Must", "May"],
+ row_label="Predicted",
+ column_label="Ground truth",
+ )
+ exp.log_confusion_matrix(
+ file_name="confusion_matrix_std.json",
+ title="confusion_matrix_std.json",
+ matrix=confmat_std,
+ labels=["Cannot", "Must", "May"],
+ row_label="Predicted",
+ column_label="Ground truth",
+ )
+ exp.log_metrics(dict(means))
+ exp.log_table("metrics.csv", df)
+ exp.log_html(df.to_html(col_space="80px"))
+ exp.log_parameters(vars(args))
+ exp.log_parameter("eval_path", str(eval_path))
+ exp.add_tag("eval_masker")
+ if args.tags:
+ exp.add_tags(args.tags)
+ exp.log_parameter("model_id", Path(eval_path).name)
+
+ # Close comet
+ exp.end()
+
+ # --------------------------------
+    # ----- END OF MODELS LOOP -----
+ # --------------------------------
+
+ # Compare models
+ if (args.load_metrics or args.write_metrics) and len(evaluations) > 1:
+ print(
+ "Plots for comparing the input models will be created and logged to comet"
+ )
+
+ # Initialize New Comet Experiment
+ exp = Experiment(
+ project_name="climategan-masker-metrics", display_summary_level=0
+ )
+ if args.tags:
+ exp.add_tags(args.tags)
+
+ # Build DataFrame with all models
+ print("Building pandas DataFrame...")
+ models_df = {}
+ for (m, model_path) in enumerate(evaluations):
+ model_path = Path(model_path)
+ with open(model_path / "opts.yaml", "r") as f:
+ opt = yaml.safe_load(f)
+ model_feats = ", ".join(
+ [
+ t
+ for t in sorted(opt["comet"]["tags"])
+ if "branch" not in t and "ablation" not in t and "trash" not in t
+ ]
+ )
+ model_id = f"{model_path.parent.name[-2:]}/{model_path.name}"
+ df_m = pd.read_csv(
+ model_path / "eval-metrics" / "eval_masker.csv", index_col=False
+ )
+ df_m["model"] = [model_id] * len(df_m)
+ df_m["model_idx"] = [m] * len(df_m)
+ df_m["model_feats"] = [model_feats] * len(df_m)
+ models_df.update({model_id: df_m})
+ df = pd.concat(list(models_df.values()), ignore_index=True)
+ df["model_img_idx"] = df.model.astype(str) + "-" + df.idx.astype(str)
+ df.rename(columns={"idx": "img_idx"}, inplace=True)
+ dict_models_labels = {
+ k: f"{v['model_idx'][0]}: {v['model_feats'][0]}"
+ for k, v in models_df.items()
+ }
+ print("Done")
+
+ if args.output_csv:
+ print(f"Writing DataFrame to {args.output_csv}")
+ df.to_csv(args.output_csv, index_label="model_img_idx")
+
+ # Determine images with low metrics in any model
+ print("Constructing filter based on metrics thresholds...")
+ idx_not_good_in_any = []
+ for idx in df.img_idx.unique():
+ df_th = df.loc[
+ (
+ # TODO: rethink thresholds
+ (df.tpr <= dict_metrics["threshold"]["tpr"])
+ | (df.fpr >= dict_metrics["threshold"]["fpr"])
+ | (df.edge_coherence >= dict_metrics["threshold"]["edge_coherence"])
+ )
+ & ((df.img_idx == idx) & (df.model.isin(df.model.unique())))
+ ]
+ if len(df_th) > 0:
+ idx_not_good_in_any.append(idx)
+ filters = {"all": df.img_idx.unique(), "not_good_in_any": idx_not_good_in_any}
+ print("Done")
+
+ # Boxplots of metrics
+ print("Plotting boxplots of metrics...")
+ for k, f in filters.items():
+ print(f"\tDistribution of [{k}] images...")
+ for metric in dict_metrics["names"].keys():
+ fig_filename = plot_dir / f"boxplot_{metric}_{k}.png"
+ if metric in ["mnr", "mpr", "accuracy_must_may"]:
+ boxplot_metric(
+ fig_filename,
+ df.loc[df.img_idx.isin(f)],
+ metric=metric,
+ dict_metrics=dict_metrics["names"],
+ do_stripplot=True,
+ dict_models=dict_models_labels,
+ order=list(df.model.unique()),
+ )
+ else:
+ boxplot_metric(
+ fig_filename,
+ df.loc[df.img_idx.isin(f)],
+ metric=metric,
+ dict_metrics=dict_metrics["names"],
+ dict_models=dict_models_labels,
+ fliersize=1.0,
+ order=list(df.model.unique()),
+ )
+ exp.log_image(fig_filename)
+ print("Done")
+
+ # Cluster Maps
+ print("Plotting clustermaps...")
+ for k, f in filters.items():
+ print(f"\tDistribution of [{k}] images...")
+ for metric in dict_metrics["names"].keys():
+ fig_filename = plot_dir / f"clustermap_{metric}_{k}.png"
+ df_mf = df.loc[df.img_idx.isin(f)].pivot("img_idx", "model", metric)
+ clustermap_metric(
+ output_filename=fig_filename,
+ df=df_mf,
+ metric=metric,
+ dict_metrics=dict_metrics["names"],
+ method="average",
+ cluster_metric="euclidean",
+ dict_models=dict_models_labels,
+ row_cluster=False,
+ )
+ exp.log_image(fig_filename)
+ print("Done")
+
+ # Close comet
+ exp.end()
diff --git a/figures/ablation_comparison.py b/figures/ablation_comparison.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6bae9b40885f26ae1e4e98d7f0a7f15cfe64a5d
--- /dev/null
+++ b/figures/ablation_comparison.py
@@ -0,0 +1,394 @@
+"""
+This script compares the models of the ablation study. For every model, it plots the
+median of the key masker evaluation metrics (error, F05 score, edge coherence) over
+all test images, together with bootstrapped confidence intervals of the median and,
+optionally, strip plots of the per-image values.
+"""
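+# Example usage (the output directory is a placeholder):
+#   python figures/ablation_comparison.py \
+#       --input_csv ablations_metrics_20210311.csv --output_dir /path/to/plots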
+print("Imports...", end="")
+from argparse import ArgumentParser
+import yaml
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import os
+from pathlib import Path
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+import matplotlib.transforms as transforms
+
+
+# -----------------------
+# ----- Constants -----
+# -----------------------
+
+dict_models = {
+ "md": 11,
+ "dada_ms, msd, pseudo": 9,
+ "msd, pseudo": 4,
+ "dada, msd_spade, pseudo": 7,
+ "msd": 13,
+ "dada_m, msd": 17,
+ "dada, msd_spade": 16,
+ "msd_spade, pseudo": 5,
+ "dada_ms, msd": 18,
+ "dada, msd, pseudo": 6,
+ "ms": 12,
+ "dada, msd": 15,
+ "dada_m, msd, pseudo": 8,
+ "msd_spade": 14,
+ "m": 10,
+ "md, pseudo": 2,
+ "ms, pseudo": 3,
+ "m, pseudo": 1,
+ "ground": "G",
+ "instagan": "I",
+}
+
+dict_metrics = {
+ "names": {
+ "tpr": "TPR, Recall, Sensitivity",
+ "tnr": "TNR, Specificity, Selectivity",
+ "fpr": "FPR",
+ "fpt": "False positives relative to image size",
+ "fnr": "FNR, Miss rate",
+ "fnt": "False negatives relative to image size",
+ "mpr": "May positive rate (MPR)",
+ "mnr": "May negative rate (MNR)",
+ "accuracy": "Accuracy (ignoring may)",
+ "error": "Error",
+ "f05": "F05 score",
+ "precision": "Precision",
+ "edge_coherence": "Edge coherence",
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
+ },
+ "key_metrics": ["f05", "error", "edge_coherence"],
+}
+dict_techniques = {
+ "depth": "depth",
+ "segmentation": "seg",
+ "seg": "seg",
+ "dada_s": "dada_seg",
+ "dada_seg": "dada_seg",
+ "dada_segmentation": "dada_seg",
+ "dada_m": "dada_masker",
+ "dada_masker": "dada_masker",
+ "spade": "spade",
+ "pseudo": "pseudo",
+ "pseudo-labels": "pseudo",
+ "pseudo_labels": "pseudo",
+}
+
+# Markers
+dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
+
+# Model features
+model_feats = [
+ "masker",
+ "seg",
+ "depth",
+ "dada_seg",
+ "dada_masker",
+ "spade",
+ "pseudo",
+ "ground",
+ "instagan",
+]
+
+# Colors
+palette_colorblind = sns.color_palette("colorblind")
+color_climategan = palette_colorblind[0]
+color_munit = palette_colorblind[1]
+color_cyclegan = palette_colorblind[6]
+color_instagan = palette_colorblind[8]
+color_maskinstagan = palette_colorblind[2]
+color_paintedground = palette_colorblind[3]
+
+color_cat1 = palette_colorblind[0]
+color_cat2 = palette_colorblind[1]
+palette_lightest = [
+ sns.light_palette(color_cat1, n_colors=20)[3],
+ sns.light_palette(color_cat2, n_colors=20)[3],
+]
+palette_light = [
+ sns.light_palette(color_cat1, n_colors=3)[1],
+ sns.light_palette(color_cat2, n_colors=3)[1],
+]
+palette_medium = [color_cat1, color_cat2]
+palette_dark = [
+ sns.dark_palette(color_cat1, n_colors=3)[1],
+ sns.dark_palette(color_cat2, n_colors=3)[1],
+]
+palette_cat1 = [
+ palette_lightest[0],
+ palette_light[0],
+ palette_medium[0],
+ palette_dark[0],
+]
+palette_cat2 = [
+ palette_lightest[1],
+ palette_light[1],
+ palette_medium[1],
+ palette_dark[1],
+]
+color_cat1_light = palette_light[0]
+color_cat2_light = palette_light[1]
+
+
+def parsed_args():
+ """
+    Parses and returns command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--input_csv",
+ default="ablations_metrics_20210311.csv",
+ type=str,
+ help="CSV containing the results of the ablation study",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--models",
+ default="all",
+ type=str,
+ help="Models to display: all, pseudo, no_dada_masker, no_baseline",
+ )
+ parser.add_argument(
+ "--dpi",
+ default=200,
+ type=int,
+ help="DPI for the output images",
+ )
+ parser.add_argument(
+ "--n_bs",
+        default=int(1e6),
+        type=int,
+        help="Number of bootstrap samples",
+ )
+ parser.add_argument(
+ "--alpha",
+ default=0.99,
+ type=float,
+ help="Confidence level",
+ )
+ parser.add_argument(
+ "--bs_seed",
+ default=17,
+ type=int,
+ help="Bootstrap random seed, for reproducibility",
+ )
+
+ return parser.parse_args()
+
+
+def plot_median_metrics(
+ df, do_stripplot=True, dpi=200, bs_seed=37, n_bs=1000, **snskwargs
+):
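+    """
+    Plots, for every model, the median and a bootstrapped confidence interval of
+    the error, F05 and edge coherence metrics, with optional strip plots of the
+    per-image values.
+    """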
+ def plot_metric(
+ ax, df, metric, do_stripplot=True, dpi=200, bs_seed=37, marker="o", **snskwargs
+ ):
+
+ y_labels = [dict_models[f] for f in df.model_feats.unique()]
+
+ # Labels
+ y_labels_int = np.sort([el for el in y_labels if isinstance(el, int)]).tolist()
+ y_order_int = [
+ k for vs in y_labels_int for k, vu in dict_models.items() if vs == vu
+ ]
+ y_labels_int = [str(el) for el in y_labels_int]
+
+ y_labels_str = sorted([el for el in y_labels if not isinstance(el, int)])
+ y_order_str = [
+ k for vs in y_labels_str for k, vu in dict_models.items() if vs == vu
+ ]
+ y_labels = y_labels_int + y_labels_str
+ y_order = y_order_int + y_order_str
+
+ # Palette
+ palette = len(y_labels_int) * [color_climategan]
+ for y in y_labels_str:
+ if y == "G":
+ palette = palette + [color_paintedground]
+ if y == "I":
+ palette = palette + [color_maskinstagan]
+
+ # Error
+ sns.pointplot(
+ ax=ax,
+ data=df,
+ x=metric,
+ y="model_feats",
+ order=y_order,
+ markers=marker,
+ estimator=np.median,
+ ci=99,
+ seed=bs_seed,
+ n_boot=n_bs,
+ join=False,
+ scale=0.6,
+ errwidth=1.5,
+ capsize=0.1,
+ palette=palette,
+ )
+ xlim = ax.get_xlim()
+
+ if do_stripplot:
+ sns.stripplot(
+ ax=ax,
+ data=df,
+ x=metric,
+ y="model_feats",
+ size=1.5,
+ palette=palette,
+ alpha=0.2,
+ )
+ ax.set_xlim(xlim)
+
+ # Set X-label
+ ax.set_xlabel(dict_metrics["names"][metric], rotation=0, fontsize="medium")
+
+ # Set Y-label
+ ax.set_ylabel(None)
+
+ ax.set_yticklabels(y_labels, fontsize="medium")
+
+ # Change spines
+ sns.despine(ax=ax, left=True, bottom=True)
+
+ # Draw gray area on final model
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ trans = transforms.blended_transform_factory(ax.transAxes, ax.transData)
+ rect = mpatches.Rectangle(
+ xy=(0.0, 5.5),
+ width=1,
+ height=1,
+ transform=trans,
+ linewidth=0.0,
+ edgecolor="none",
+ facecolor="gray",
+ alpha=0.05,
+ )
+ ax.add_patch(rect)
+
+ # Set up plot
+ sns.set(style="whitegrid")
+ plt.rcParams.update({"font.family": "serif"})
+ plt.rcParams.update(
+ {
+ "font.serif": [
+ "Computer Modern Roman",
+ "Times New Roman",
+ "Utopia",
+ "New Century Schoolbook",
+ "Century Schoolbook L",
+ "ITC Bookman",
+ "Bookman",
+ "Times",
+ "Palatino",
+ "Charter",
+                "serif",
+                "Bitstream Vera Serif",
+ "DejaVu Serif",
+ ]
+ }
+ )
+
+ fig_h = 0.4 * len(df.model_feats.unique())
+ fig, axes = plt.subplots(
+ nrows=1, ncols=3, sharey=True, dpi=dpi, figsize=(18, fig_h)
+ )
+
+ # Error
+ plot_metric(
+ axes[0],
+ df,
+ "error",
+ do_stripplot=do_stripplot,
+ dpi=dpi,
+ bs_seed=bs_seed,
+ marker=dict_markers["error"],
+ )
+ axes[0].set_ylabel("Models")
+
+ # F05
+ plot_metric(
+ axes[1],
+ df,
+ "f05",
+ do_stripplot=do_stripplot,
+ dpi=dpi,
+ bs_seed=bs_seed,
+ marker=dict_markers["f05"],
+ )
+
+ # Edge coherence
+ plot_metric(
+ axes[2],
+ df,
+ "edge_coherence",
+ do_stripplot=do_stripplot,
+ dpi=dpi,
+ bs_seed=bs_seed,
+ marker=dict_markers["edge_coherence"],
+ )
+ xticks = axes[2].get_xticks()
+ xticklabels = ["{:.3f}".format(x) for x in xticks]
+ axes[2].set(xticks=xticks, xticklabels=xticklabels)
+
+ plt.subplots_adjust(wspace=0.12)
+
+ return fig
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ if args.output_dir is None:
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
+ else:
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ # Store args
+ output_yml = output_dir / "ablation_comparison_{}.yml".format(args.models)
+ with open(output_yml, "w") as f:
+ yaml.dump(vars(args), f)
+
+ # Read CSV
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
+
+ # Determine models
+ if "all" in args.models.lower():
+ pass
+ else:
+ if "no_baseline" in args.models.lower():
+ df = df.loc[(df.ground == False) & (df.instagan == False)]
+ if "pseudo" in args.models.lower():
+ df = df.loc[
+ (df.pseudo == True) | (df.ground == True) | (df.instagan == True)
+ ]
+ if "no_dada_mask" in args.models.lower():
+ df = df.loc[
+ (df.dada_masker == False) | (df.ground == True) | (df.instagan == True)
+ ]
+
+ fig = plot_median_metrics(df, do_stripplot=True, dpi=args.dpi, bs_seed=args.bs_seed)
+
+ # Save figure
+ output_fig = output_dir / "ablation_comparison_{}.png".format(args.models)
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
diff --git a/figures/bootstrap_ablation.py b/figures/bootstrap_ablation.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c7f4e876e4543323e17b6e13c478d0e9ff47a98
--- /dev/null
+++ b/figures/bootstrap_ablation.py
@@ -0,0 +1,562 @@
+"""
+This script evaluates the contribution of a technique from the ablation study to
+improving the masker evaluation metrics. The differences in the metrics are computed
+for all images of paired models, i.e. models that differ only in whether they include
+the given technique. Then, statistical inference is performed through the
+percentile bootstrap to obtain robust estimates of the differences in the metrics and
+confidence intervals. The script plots the distribution of the bootstrapped estimates.
+"""
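+# Example usage (the output directory is a placeholder):
+#   python figures/bootstrap_ablation.py --technique depth \
+#       --input_csv ablations_metrics_20210311.csv --output_dir /path/to/plots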
+print("Imports...", end="")
+from argparse import ArgumentParser
+import yaml
+import os
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from scipy.stats import trim_mean
+from tqdm import tqdm
+from pathlib import Path
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+
+
+# -----------------------
+# ----- Constants -----
+# -----------------------
+
+dict_metrics = {
+ "names": {
+ "tpr": "TPR, Recall, Sensitivity",
+ "tnr": "TNR, Specificity, Selectivity",
+ "fpr": "FPR",
+ "fpt": "False positives relative to image size",
+ "fnr": "FNR, Miss rate",
+ "fnt": "False negatives relative to image size",
+ "mpr": "May positive rate (MPR)",
+ "mnr": "May negative rate (MNR)",
+ "accuracy": "Accuracy (ignoring may)",
+ "error": "Error",
+ "f05": "F05 score",
+ "precision": "Precision",
+ "edge_coherence": "Edge coherence",
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
+ },
+ "key_metrics": ["f05", "error", "edge_coherence"],
+}
+dict_techniques = {
+ "depth": "depth",
+ "segmentation": "seg",
+ "seg": "seg",
+ "dada_s": "dada_seg",
+ "dada_seg": "dada_seg",
+ "dada_segmentation": "dada_seg",
+ "dada_m": "dada_masker",
+ "dada_masker": "dada_masker",
+ "spade": "spade",
+ "pseudo": "pseudo",
+ "pseudo-labels": "pseudo",
+ "pseudo_labels": "pseudo",
+}
+
+# Model features
+model_feats = [
+ "masker",
+ "seg",
+ "depth",
+ "dada_seg",
+ "dada_masker",
+ "spade",
+ "pseudo",
+ "ground",
+ "instagan",
+]
+
+# Colors
+palette_colorblind = sns.color_palette("colorblind")
+color_cat1 = palette_colorblind[0]
+color_cat2 = palette_colorblind[1]
+palette_lightest = [
+ sns.light_palette(color_cat1, n_colors=20)[3],
+ sns.light_palette(color_cat2, n_colors=20)[3],
+]
+palette_light = [
+ sns.light_palette(color_cat1, n_colors=3)[1],
+ sns.light_palette(color_cat2, n_colors=3)[1],
+]
+palette_medium = [color_cat1, color_cat2]
+palette_dark = [
+ sns.dark_palette(color_cat1, n_colors=3)[1],
+ sns.dark_palette(color_cat2, n_colors=3)[1],
+]
+palette_cat1 = [
+ palette_lightest[0],
+ palette_light[0],
+ palette_medium[0],
+ palette_dark[0],
+]
+palette_cat2 = [
+ palette_lightest[1],
+ palette_light[1],
+ palette_medium[1],
+ palette_dark[1],
+]
+color_cat1_light = palette_light[0]
+color_cat2_light = palette_light[1]
+
+
+def parsed_args():
+ """
+    Parses and returns command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--input_csv",
+ default="ablations_metrics_20210311.csv",
+ type=str,
+ help="CSV containing the results of the ablation study",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--technique",
+ default=None,
+ type=str,
+ help="Keyword specifying the technique. One of: pseudo, depth, segmentation, dada_seg, dada_masker, spade",
+ )
+ parser.add_argument(
+ "--dpi",
+ default=200,
+ type=int,
+ help="DPI for the output images",
+ )
+ parser.add_argument(
+ "--n_bs",
+        default=int(1e6),
+        type=int,
+        help="Number of bootstrap samples",
+ )
+ parser.add_argument(
+ "--alpha",
+ default=0.99,
+ type=float,
+ help="Confidence level",
+ )
+ parser.add_argument(
+ "--bs_seed",
+ default=17,
+ type=int,
+ help="Bootstrap random seed, for reproducibility",
+ )
+
+ return parser.parse_args()
+
+
+def add_ci_mean(
+ ax, sample_measure, bs_mean, bs_std, ci, color, alpha, fontsize, invert=False
+):
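+    """
+    Shades the confidence interval under the KDE curve, marks its bounds, and
+    annotates the bootstrap mean, standard deviation and sample measure.
+    """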
+
+ # Fill area between CI
+ dist = ax.lines[0]
+ dist_y = dist.get_ydata()
+ dist_x = dist.get_xdata()
+ linewidth = dist.get_linewidth()
+
+ x_idx_low = np.argmin(np.abs(dist_x - ci[0]))
+ x_idx_high = np.argmin(np.abs(dist_x - ci[1]))
+ x_ci = dist_x[x_idx_low:x_idx_high]
+ y_ci = dist_y[x_idx_low:x_idx_high]
+
+ ax.fill_between(x_ci, 0, y_ci, facecolor=color, alpha=alpha)
+
+ # Add vertical lines of CI
+ ax.vlines(
+ x=ci[0],
+ ymin=0.0,
+ ymax=y_ci[0],
+ color=color,
+ linewidth=linewidth,
+ label="ci_low",
+ )
+ ax.vlines(
+ x=ci[1],
+ ymin=0.0,
+ ymax=y_ci[-1],
+ color=color,
+ linewidth=linewidth,
+ label="ci_high",
+ )
+
+ # Add annotations
+ bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
+
+ if invert:
+ ha_l = "right"
+ ha_u = "left"
+ else:
+ ha_l = "left"
+ ha_u = "right"
+ ax.text(
+ ci[0],
+ 0.0,
+ s="L = {:.4f}".format(ci[0]),
+ ha=ha_l,
+ va="bottom",
+ fontsize=fontsize,
+ bbox=bbox_props,
+ )
+ ax.text(
+ ci[1],
+ 0.0,
+ s="U = {:.4f}".format(ci[1]),
+ ha=ha_u,
+ va="bottom",
+ fontsize=fontsize,
+ bbox=bbox_props,
+ )
+
+ # Add vertical line of bootstrap mean
+ x_idx_mean = np.argmin(np.abs(dist_x - bs_mean))
+ ax.vlines(
+ x=bs_mean, ymin=0.0, ymax=dist_y[x_idx_mean], color="k", linewidth=linewidth
+ )
+
+ # Add annotation of bootstrap mean
+ bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
+
+ ax.text(
+ bs_mean,
+ 0.6 * dist_y[x_idx_mean],
+ s="Bootstrap mean = {:.4f}".format(bs_mean),
+ ha="center",
+ va="center",
+ fontsize=fontsize,
+ bbox=bbox_props,
+ )
+
+ # Add vertical line of sample_measure
+ x_idx_smeas = np.argmin(np.abs(dist_x - sample_measure))
+ ax.vlines(
+ x=sample_measure,
+ ymin=0.0,
+ ymax=dist_y[x_idx_smeas],
+ color="k",
+ linewidth=linewidth,
+ linestyles="dotted",
+ )
+
+ # Add SD
+ bbox_props = dict(boxstyle="darrow, pad=0.4", fc="w", ec="k", lw=2)
+
+ ax.text(
+ bs_mean,
+ 0.4 * dist_y[x_idx_mean],
+ s="SD = {:.4f} = SE".format(bs_std),
+ ha="center",
+ va="center",
+ fontsize=fontsize,
+ bbox=bbox_props,
+ )
+
+
+def add_null_pval(ax, null, color, alpha, fontsize):
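+    """
+    Shades the tail of the bootstrap distribution beyond the null value and draws
+    an annotated vertical line at the null hypothesis.
+    """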
+
+ # Fill area between CI
+ dist = ax.lines[0]
+ dist_y = dist.get_ydata()
+ dist_x = dist.get_xdata()
+ linewidth = dist.get_linewidth()
+
+ x_idx_null = np.argmin(np.abs(dist_x - null))
+ if x_idx_null >= (len(dist_x) / 2.0):
+ x_pval = dist_x[x_idx_null:]
+ y_pval = dist_y[x_idx_null:]
+ else:
+ x_pval = dist_x[:x_idx_null]
+ y_pval = dist_y[:x_idx_null]
+
+ ax.fill_between(x_pval, 0, y_pval, facecolor=color, alpha=alpha)
+
+ # Add vertical lines of null
+ dist = ax.lines[0]
+ linewidth = dist.get_linewidth()
+ y_max = ax.get_ylim()[1]
+ ax.vlines(
+ x=null,
+ ymin=0.0,
+ ymax=y_max,
+ color="k",
+ linewidth=linewidth,
+ linestyles="dotted",
+ )
+
+ # Add annotations
+ bbox_props = dict(boxstyle="round, pad=0.4", fc="w", ec="k", lw=2)
+
+ ax.text(
+ null,
+ 0.75 * y_max,
+ s="Null hypothesis = {:.1f}".format(null),
+ ha="center",
+ va="center",
+ fontsize=fontsize,
+ bbox=bbox_props,
+ )
+
+
+def plot_bootstrap_distr(
+ sample_measure, bs_samples, alpha, color_ci, color_pval=None, null=None
+):
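+    """
+    Plots the KDE of the bootstrap estimates, with the confidence interval, the
+    bootstrap mean and standard deviation and, if null is provided, a two-sided
+    p-value. Returns the figure and the summary statistics.
+    """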
+
+ # Compute results from bootstrap
+ q_low = (1.0 - alpha) / 2.0
+ q_high = 1.0 - q_low
+ ci = np.quantile(bs_samples, [q_low, q_high])
+ bs_mean = np.mean(bs_samples)
+ bs_std = np.std(bs_samples)
+
+ if null is not None and color_pval is not None:
+ pval_flag = True
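+        # Two-sided p-value: twice the smaller of the two tail proportions of
+        # bootstrap samples falling beyond the null value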
+ pval = np.min([[np.mean(bs_samples > null), np.mean(bs_samples < null)]]) * 2
+    else:
+        pval_flag = False
+        pval = None
+
+ # Set up plot
+ sns.set(style="whitegrid")
+ fontsize = 24
+ font = {"family": "DejaVu Sans", "weight": "normal", "size": fontsize}
+ plt.rc("font", **font)
+ alpha_plot = 0.5
+
+ # Initialize the matplotlib figure
+ fig, ax = plt.subplots(figsize=(30, 12), dpi=args.dpi)
+
+ # Plot distribution of bootstrap means
+ sns.kdeplot(bs_samples, color="b", linewidth=5, gridsize=1000, ax=ax)
+
+ y_lim = ax.get_ylim()
+
+ # Change spines
+ sns.despine(left=True, bottom=True)
+
+ # Annotations
+ add_ci_mean(
+ ax,
+ sample_measure,
+ bs_mean,
+ bs_std,
+ ci,
+ color=color_ci,
+ alpha=alpha_plot,
+ fontsize=fontsize,
+ )
+
+ if pval_flag:
+ add_null_pval(ax, null, color=color_pval, alpha=alpha_plot, fontsize=fontsize)
+
+ # Legend
+ ci_patch = mpatches.Patch(
+ facecolor=color_ci,
+ edgecolor=None,
+ alpha=alpha_plot,
+ label="{:d} % confidence interval".format(int(100 * alpha)),
+ )
+
+ if pval_flag:
+ if pval == 0.0:
+ pval_patch = mpatches.Patch(
+ facecolor=color_pval,
+ edgecolor=None,
+ alpha=alpha_plot,
+ label="P value / 2 = {:.1f}".format(pval / 2.0),
+ )
+ elif np.around(pval / 2.0, decimals=4) > 0.0000:
+ pval_patch = mpatches.Patch(
+ facecolor=color_pval,
+ edgecolor=None,
+ alpha=alpha_plot,
+ label="P value / 2 = {:.4f}".format(pval / 2.0),
+ )
+ else:
+ pval_patch = mpatches.Patch(
+ facecolor=color_pval,
+ edgecolor=None,
+ alpha=alpha_plot,
+                label="P value / 2 < $10^{{{:.0f}}}$".format(
+                    np.ceil(np.log10(pval / 2.0))
+                ),
+ )
+
+ leg = ax.legend(
+ handles=[ci_patch, pval_patch],
+ ncol=1,
+ loc="upper right",
+ frameon=True,
+ framealpha=1.0,
+ title="",
+ fontsize=fontsize,
+ columnspacing=1.0,
+ labelspacing=0.2,
+ markerfirst=True,
+ )
+ else:
+ leg = ax.legend(
+ handles=[ci_patch],
+ ncol=1,
+ loc="upper right",
+ frameon=True,
+ framealpha=1.0,
+ title="",
+ fontsize=fontsize,
+ columnspacing=1.0,
+ labelspacing=0.2,
+ markerfirst=True,
+ )
+
+ plt.setp(leg.get_title(), fontsize=fontsize, horizontalalignment="left")
+
+ # Set X-label
+ ax.set_xlabel("Bootstrap estimates", rotation=0, fontsize=fontsize, labelpad=10.0)
+
+ # Set Y-label
+ ax.set_ylabel("Density", rotation=90, fontsize=fontsize, labelpad=10.0)
+
+ # Ticks
+ plt.setp(ax.get_xticklabels(), fontsize=0.8 * fontsize, verticalalignment="top")
+ plt.setp(ax.get_yticklabels(), fontsize=0.8 * fontsize)
+
+ ax.set_ylim(y_lim)
+
+ return fig, bs_mean, bs_std, ci, pval
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ if args.output_dir is None:
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
+ else:
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ # Store args
+ output_yml = output_dir / "{}_bootstrap.yml".format(args.technique)
+ with open(output_yml, "w") as f:
+ yaml.dump(vars(args), f)
+
+ # Determine technique
+ if args.technique.lower() not in dict_techniques:
+ raise ValueError("{} is not a valid technique".format(args.technique))
+ else:
+ technique = dict_techniques[args.technique.lower()]
+
+ # Read CSV
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
+
+ # Find relevant model pairs
+ model_pairs = []
+ for mi in df.loc[df[technique]].model_feats.unique():
+ for mj in df.model_feats.unique():
+ if mj == mi:
+ continue
+
+ if df.loc[df.model_feats == mj, technique].unique()[0]:
+ continue
+
+ is_pair = True
+ for f in model_feats:
+ if f == technique:
+ continue
+ elif (
+ df.loc[df.model_feats == mj, f].unique()[0]
+ != df.loc[df.model_feats == mi, f].unique()[0]
+ ):
+ is_pair = False
+ break
+ else:
+ pass
+ if is_pair:
+ model_pairs.append((mi, mj))
+ break
+
+ print("\nModel pairs identified:\n")
+ for pair in model_pairs:
+ print("{} & {}".format(pair[0], pair[1]))
+
+ df["base"] = ["N/A"] * len(df)
+ for spp in model_pairs:
+        df.loc[df.model_feats.isin(spp), "base"] = spp[1]
+
+ # Build bootstrap data
+ data = {m: [] for m in dict_metrics["key_metrics"]}
+ for m_with, m_without in model_pairs:
+ df_with = df.loc[df.model_feats == m_with]
+ df_without = df.loc[df.model_feats == m_without]
+ for metric in data.keys():
+ diff = (
+ df_with.sort_values(by="img_idx")[metric].values
+ - df_without.sort_values(by="img_idx")[metric].values
+ )
+ data[metric].extend(diff.tolist())
+
+ # Run bootstrap
+ measures = ["mean", "median", "20_trimmed_mean"]
+ bs_data = {meas: {m: np.zeros(args.n_bs) for m in data.keys()} for meas in measures}
+
+ np.random.seed(args.bs_seed)
+ for m, data_m in data.items():
+ for idx, s in enumerate(tqdm(range(args.n_bs))):
+ # Sample with replacement
+ bs_sample = np.random.choice(data_m, size=len(data_m), replace=True)
+
+ # Store mean
+ bs_data["mean"][m][idx] = np.mean(bs_sample)
+
+ # Store median
+ bs_data["median"][m][idx] = np.median(bs_sample)
+
+ # Store 20 % trimmed mean
+ bs_data["20_trimmed_mean"][m][idx] = trim_mean(bs_sample, 0.2)
+
+    # Plot the bootstrap distribution and store the results for each key metric
+    for metric in dict_metrics["key_metrics"]:
+        sample_measure = trim_mean(data[metric], 0.2)
+        fig, bs_mean, bs_std, ci, pval = plot_bootstrap_distr(
+            sample_measure,
+            bs_data["20_trimmed_mean"][metric],
+            alpha=args.alpha,
+            color_ci=color_cat1_light,
+            color_pval=color_cat2_light,
+            null=0.0,
+        )
+
+        # Save figure
+        output_fig = output_dir / "{}_bootstrap_{}_{}.png".format(
+            args.technique, metric, "20_trimmed_mean"
+        )
+        fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
+
+        # Store results
+        output_results = output_dir / "{}_bootstrap_{}_{}.yml".format(
+            args.technique, metric, "20_trimmed_mean"
+        )
+        results_dict = {
+            "measure": "20_trimmed_mean",
+            "sample_measure": float(sample_measure),
+            "bs_mean": float(bs_mean),
+            "bs_std": float(bs_std),
+            "ci_left": float(ci[0]),
+            "ci_right": float(ci[1]),
+            "pval": float(pval),
+        }
+        with open(output_results, "w") as f:
+            yaml.dump(results_dict, f)
diff --git a/figures/bootstrap_ablation_summary.py b/figures/bootstrap_ablation_summary.py
new file mode 100644
index 0000000000000000000000000000000000000000..e64a7b86d737a1a2ce422b2f14850d7f00169e23
--- /dev/null
+++ b/figures/bootstrap_ablation_summary.py
@@ -0,0 +1,361 @@
+"""
+This script computes, for every technique of the ablation study, the 20 % trimmed mean
+difference in the masker evaluation metrics and its bootstrapped confidence interval.
+The differences in the metrics are computed for all images of paired models, i.e.
+models that differ only in whether they include the given technique. Statistical
+inference is performed through the percentile bootstrap to obtain robust estimates of
+the differences in the metrics and confidence intervals. The script plots the summary
+for all techniques.
+"""
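+# Example usage (the output directory is a placeholder):
+#   python figures/bootstrap_ablation_summary.py \
+#       --input_csv ablations_metrics_20210311.csv --output_dir /path/to/plots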
+print("Imports...", end="")
+from argparse import ArgumentParser
+import yaml
+import os
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from scipy.special import comb
+from scipy.stats import trim_mean
+from tqdm import tqdm
+from collections import OrderedDict
+from pathlib import Path
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+import matplotlib.transforms as transforms
+
+
+# -----------------------
+# ----- Constants -----
+# -----------------------
+
+dict_metrics = {
+ "names": {
+ "tpr": "TPR, Recall, Sensitivity",
+ "tnr": "TNR, Specificity, Selectivity",
+ "fpr": "FPR",
+ "fpt": "False positives relative to image size",
+ "fnr": "FNR, Miss rate",
+ "fnt": "False negatives relative to image size",
+ "mpr": "May positive rate (MPR)",
+ "mnr": "May negative rate (MNR)",
+ "accuracy": "Accuracy (ignoring may)",
+ "error": "Error",
+ "f05": "F05 score",
+ "precision": "Precision",
+ "edge_coherence": "Edge coherence",
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
+ },
+ "key_metrics": ["error", "f05", "edge_coherence"],
+}
+
+dict_techniques = OrderedDict(
+ [
+ ("pseudo", "Pseudo labels"),
+ ("depth", "Depth (D)"),
+ ("seg", "Seg. (S)"),
+ ("spade", "SPADE"),
+ ("dada_seg", "DADA (S)"),
+ ("dada_masker", "DADA (M)"),
+ ]
+)
+
+# Model features
+model_feats = [
+ "masker",
+ "seg",
+ "depth",
+ "dada_seg",
+ "dada_masker",
+ "spade",
+ "pseudo",
+ "ground",
+ "instagan",
+]
+
+# Colors
+crest = sns.color_palette("crest", as_cmap=False, n_colors=7)
+palette_metrics = [crest[0], crest[3], crest[6]]
+sns.palplot(palette_metrics)
+
+# Markers
+dict_markers = {"error": "o", "f05": "s", "edge_coherence": "^"}
+
+
+def parsed_args():
+ """
+    Parses and returns command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--input_csv",
+ default="ablations_metrics_20210311.csv",
+ type=str,
+ help="CSV containing the results of the ablation study",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--dpi",
+ default=200,
+ type=int,
+ help="DPI for the output images",
+ )
+ parser.add_argument(
+ "--n_bs",
+        default=int(1e6),
+        type=int,
+        help="Number of bootstrap samples",
+ )
+ parser.add_argument(
+ "--alpha",
+ default=0.99,
+ type=float,
+ help="Confidence level",
+ )
+ parser.add_argument(
+ "--bs_seed",
+ default=17,
+ type=int,
+ help="Bootstrap random seed, for reproducibility",
+ )
+
+ return parser.parse_args()
+
+
+def trim_mean_wrapper(a):
+ return trim_mean(a, proportiontocut=0.2)
+
+
+def find_model_pairs(technique, model_feats):
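+    """
+    Returns the pairs (model with the technique, model without it) whose other
+    features in model_feats are identical. Relies on the global DataFrame df.
+    """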
+ model_pairs = []
+ for mi in df.loc[df[technique]].model_feats.unique():
+ for mj in df.model_feats.unique():
+ if mj == mi:
+ continue
+
+ if df.loc[df.model_feats == mj, technique].unique()[0]:
+ continue
+
+ is_pair = True
+ for f in model_feats:
+ if f == technique:
+ continue
+ elif (
+ df.loc[df.model_feats == mj, f].unique()[0]
+ != df.loc[df.model_feats == mi, f].unique()[0]
+ ):
+ is_pair = False
+ break
+ else:
+ pass
+ if is_pair:
+ model_pairs.append((mi, mj))
+ break
+ return model_pairs
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ if args.output_dir is None:
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
+ else:
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ # Store args
+ output_yml = output_dir / "bootstrap_summary.yml"
+ with open(output_yml, "w") as f:
+ yaml.dump(vars(args), f)
+
+ # Read CSV
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
+
+ # Build data set
+ dfbs = pd.DataFrame(columns=["diff", "technique", "metric"])
+ for technique in model_feats:
+
+ # Get pairs
+ model_pairs = find_model_pairs(technique, model_feats)
+
+ # Compute differences
+ for m_with, m_without in model_pairs:
+ df_with = df.loc[df.model_feats == m_with]
+ df_without = df.loc[df.model_feats == m_without]
+ for metric in dict_metrics["key_metrics"]:
+ diff = (
+ df_with.sort_values(by="img_idx")[metric].values
+ - df_without.sort_values(by="img_idx")[metric].values
+ )
+ dfm = pd.DataFrame.from_dict(
+ {"metric": metric, "technique": technique, "diff": diff}
+ )
+ dfbs = dfbs.append(dfm, ignore_index=True)
+
+ ### Plot
+
+ # Set up plot
+ sns.reset_orig()
+ sns.set(style="whitegrid")
+ plt.rcParams.update({"font.family": "serif"})
+ plt.rcParams.update(
+ {
+ "font.serif": [
+ "Computer Modern Roman",
+ "Times New Roman",
+ "Utopia",
+ "New Century Schoolbook",
+ "Century Schoolbook L",
+ "ITC Bookman",
+ "Bookman",
+ "Times",
+ "Palatino",
+ "Charter",
+                "serif",
+                "Bitstream Vera Serif",
+ "DejaVu Serif",
+ ]
+ }
+ )
+
+ fig, axes = plt.subplots(
+ nrows=1, ncols=3, sharey=True, dpi=args.dpi, figsize=(9, 3)
+ )
+
+ metrics = ["error", "f05", "edge_coherence"]
+ dict_ci = {m: {} for m in metrics}
+
+ for idx, metric in enumerate(dict_metrics["key_metrics"]):
+
+ ax = sns.pointplot(
+ ax=axes[idx],
+ data=dfbs.loc[dfbs.metric.isin(["error", "f05", "edge_coherence"])],
+ order=dict_techniques.keys(),
+ x="diff",
+ y="technique",
+ hue="metric",
+ hue_order=[metric],
+ markers=dict_markers[metric],
+ palette=[palette_metrics[idx]],
+ errwidth=1.5,
+ scale=0.6,
+ join=False,
+ estimator=trim_mean_wrapper,
+ ci=int(args.alpha * 100),
+ n_boot=args.n_bs,
+ seed=args.bs_seed,
+ )
+
+ # Retrieve confidence intervals and update results dictionary
+ for line, technique in zip(ax.lines, dict_techniques.keys()):
+ dict_ci[metric].update(
+ {
+ technique: {
+ "20_trimmed_mean": float(
+ trim_mean_wrapper(
+ dfbs.loc[
+ (dfbs.technique == technique)
+ & (dfbs.metric == metrics[idx]),
+ "diff",
+ ].values
+ )
+ ),
+ "ci_left": float(line.get_xdata()[0]),
+ "ci_right": float(line.get_xdata()[1]),
+ }
+ }
+ )
+
+ leg_handles, leg_labels = ax.get_legend_handles_labels()
+
+ # Change spines
+ sns.despine(left=True, bottom=True)
+
+ # Set Y-label
+ ax.set_ylabel(None)
+
+ # Y-tick labels
+ ax.set_yticklabels(list(dict_techniques.values()), fontsize="medium")
+
+ # Set X-label
+ ax.set_xlabel(None)
+
+ # X-ticks
+ xticks = ax.get_xticks()
+ xticklabels = xticks
+ ax.set_xticks(xticks)
+ ax.set_xticklabels(xticklabels, fontsize="small")
+
+ # Y-lim
+ display2data = ax.transData.inverted()
+ ax2display = ax.transAxes
+ _, y_bottom = display2data.transform(ax.transAxes.transform((0.0, 0.02)))
+ _, y_top = display2data.transform(ax.transAxes.transform((0.0, 0.98)))
+ ax.set_ylim(bottom=y_bottom, top=y_top)
+
+ # Draw line at H0
+ y = np.arange(ax.get_ylim()[1], ax.get_ylim()[0], 0.1)
+ x = 0.0 * np.ones(y.shape[0])
+ ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")
+
+ # Draw gray area
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ if metric == "error":
+ x0 = xlim[0]
+ width = np.abs(x0)
+ else:
+ x0 = 0.0
+ width = np.abs(xlim[1])
+ trans = transforms.blended_transform_factory(ax.transData, ax.transAxes)
+ rect = mpatches.Rectangle(
+ xy=(x0, 0.0),
+ width=width,
+ height=1,
+ transform=trans,
+ linewidth=0.0,
+ edgecolor="none",
+ facecolor="gray",
+ alpha=0.05,
+ )
+ ax.add_patch(rect)
+
+ # Legend
+ leg_handles, leg_labels = ax.get_legend_handles_labels()
+ leg_labels = [dict_metrics["names"][metric] for metric in leg_labels]
+ leg = ax.legend(
+ handles=leg_handles,
+ labels=leg_labels,
+ loc="center",
+ title="",
+ bbox_to_anchor=(-0.2, 1.05, 1.0, 0.0),
+ framealpha=1.0,
+ frameon=False,
+ handletextpad=-0.2,
+ )
+
+    # Set X-label (title)
+ fig.suptitle(
+ "20 % trimmed mean difference and bootstrapped confidence intervals",
+ y=0.0,
+ fontsize="medium",
+ )
+
+ # Save figure
+ output_fig = output_dir / "bootstrap_summary.png"
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
+
+ # Store results
+ output_results = output_dir / "bootstrap_summary_results.yml"
+ with open(output_results, "w") as f:
+ yaml.dump(dict_ci, f)
diff --git a/figures/human_evaluation.py b/figures/human_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2889c0a945879830b844259f203612f96f759bef
--- /dev/null
+++ b/figures/human_evaluation.py
@@ -0,0 +1,208 @@
+"""
+This script plots the results of the human evaluation on Amazon Mechanical Turk, where
+human participants chose between an image from ClimateGAN and one from a different
+method.
+"""
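+# Example usage (the output directory is a placeholder):
+#   python figures/human_evaluation.py --input_csv amt_omni-vs-other.csv \
+#       --output_dir /path/to/plots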
+print("Imports...", end="")
+from argparse import ArgumentParser
+import os
+import yaml
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from pathlib import Path
+import matplotlib.pyplot as plt
+
+
+# -----------------------
+# ----- Constants -----
+# -----------------------
+
+comparables_dict = {
+ "munit_flooded": "MUNIT",
+ "cyclegan": "CycleGAN",
+ "instagan": "InstaGAN",
+ "instagan_copypaste": "Mask-InstaGAN",
+ "painted_ground": "Painted ground",
+}
+
+
+# Colors
+palette_colorblind = sns.color_palette("colorblind")
+color_climategan = palette_colorblind[9]
+
+palette_colorblind = sns.color_palette("colorblind")
+color_munit = palette_colorblind[1]
+color_cyclegan = palette_colorblind[2]
+color_instagan = palette_colorblind[3]
+color_maskinstagan = palette_colorblind[6]
+color_paintedground = palette_colorblind[8]
+palette_comparables = [
+ color_munit,
+ color_cyclegan,
+ color_instagan,
+ color_maskinstagan,
+ color_paintedground,
+]
+palette_comparables_light = [
+ sns.light_palette(color, n_colors=3)[1] for color in palette_comparables
+]
+
+
+def parsed_args():
+ """
+    Parses and returns command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--input_csv",
+ default="amt_omni-vs-other.csv",
+ type=str,
+ help="CSV containing the results of the human evaluation, pre-processed",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--dpi",
+ default=200,
+ type=int,
+ help="DPI for the output images",
+ )
+ parser.add_argument(
+ "--n_bs",
+        default=int(1e6),
+        type=int,
+        help="Number of bootstrap samples",
+ )
+ parser.add_argument(
+ "--bs_seed",
+ default=17,
+ type=int,
+ help="Bootstrap random seed, for reproducibility",
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ if args.output_dir is None:
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
+ else:
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ # Store args
+ output_yml = output_dir / "args_human_evaluation.yml"
+ with open(output_yml, "w") as f:
+ yaml.dump(vars(args), f)
+
+ # Read CSV
+ df = pd.read_csv(args.input_csv)
+
+ # Sort Y labels
+ comparables = df.comparable.unique()
+ is_climategan_sum = [
+ df.loc[df.comparable == c, "climategan"].sum() for c in comparables
+ ]
+ comparables = comparables[np.argsort(is_climategan_sum)[::-1]]
+
+ # Plot setup
+ sns.set(style="whitegrid")
+ plt.rcParams.update({"font.family": "serif"})
+ plt.rcParams.update(
+ {
+ "font.serif": [
+ "Computer Modern Roman",
+ "Times New Roman",
+ "Utopia",
+ "New Century Schoolbook",
+ "Century Schoolbook L",
+ "ITC Bookman",
+ "Bookman",
+ "Times",
+ "Palatino",
+ "Charter",
+                "serif",
+                "Bitstream Vera Serif",
+ "DejaVu Serif",
+ ]
+ }
+ )
+ fontsize = "medium"
+
+ # Initialize the matplotlib figure
+ fig, ax = plt.subplots(figsize=(10.5, 3), dpi=args.dpi)
+
+ # Plot the total (right)
+ sns.barplot(
+ data=df.loc[df.is_valid],
+ x="is_valid",
+ y="comparable",
+ order=comparables,
+ orient="h",
+ label="comparable",
+ palette=palette_comparables_light,
+ ci=None,
+ )
+
+ # Plot the left
+ sns.barplot(
+ data=df.loc[df.is_valid],
+ x="climategan",
+ y="comparable",
+ order=comparables,
+ orient="h",
+ label="climategan",
+ color=color_climategan,
+ ci=99,
+ n_boot=args.n_bs,
+ seed=args.bs_seed,
+ errcolor="black",
+ errwidth=1.5,
+ capsize=0.1,
+ )
+
+ # Draw line at 0.5
+ y = np.arange(ax.get_ylim()[1] + 0.1, ax.get_ylim()[0], 0.1)
+ x = 0.5 * np.ones(y.shape[0])
+ ax.plot(x, y, linestyle=":", linewidth=1.5, color="black")
+
+ # Change Y-Tick labels
+ yticklabels = [comparables_dict[ytick.get_text()] for ytick in ax.get_yticklabels()]
+ yticklabels_text = ax.set_yticklabels(
+ yticklabels, fontsize=fontsize, horizontalalignment="right", x=0.96
+ )
+ for ytl in yticklabels_text:
+ ax.add_artist(ytl)
+
+ # Remove Y-label
+ ax.set_ylabel(ylabel="")
+
+ # Change X-Tick labels
+ xlim = [0.0, 1.1]
+ xticks = np.arange(xlim[0], xlim[1], 0.1)
+ ax.set(xticks=xticks)
+ plt.setp(ax.get_xticklabels(), fontsize=fontsize)
+
+ # Set X-label
+ ax.set_xlabel(None)
+
+ # Change spines
+ sns.despine(left=True, bottom=True)
+
+ # Save figure
+ output_fig = output_dir / "human_evaluation_rate_climategan.png"
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
diff --git a/figures/labels.py b/figures/labels.py
new file mode 100644
index 0000000000000000000000000000000000000000..d60cb11def6277c8913e36bff3f91f744865b679
--- /dev/null
+++ b/figures/labels.py
@@ -0,0 +1,200 @@
+"""
+This script plots images from the Masker test set overlaid with their labels.
+"""
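+# Example usage (paths and image names are placeholders):
+#   python figures/labels.py --masker_test_set_dir /path/to/masker-test-set \
+#       --images example_1.jpg example_2.jpg --output_dir /path/to/plots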
+print("Imports...", end="")
+from argparse import ArgumentParser
+import os
+import yaml
+import numpy as np
+import pandas as pd
+import seaborn as sns
+from pathlib import Path
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+
+import sys
+
+sys.path.append("../")
+
+from eval_masker import crop_and_resize
+
+
+# -----------------------
+# ----- Constants -----
+# -----------------------
+
+# Colors
+colorblind_palette = sns.color_palette("colorblind")
+color_cannot = colorblind_palette[1]
+color_must = colorblind_palette[2]
+color_may = colorblind_palette[7]
+color_pred = colorblind_palette[4]
+
+icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
+color_tp = icefire[0]
+color_tn = icefire[1]
+color_fp = icefire[4]
+color_fn = icefire[3]
+
+
+def parsed_args():
+ """
+    Parses and returns command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--input_csv",
+ default="ablations_metrics_20210311.csv",
+ type=str,
+ help="CSV containing the results of the ablation study",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--masker_test_set_dir",
+ default=None,
+ type=str,
+ help="Directory containing the test images",
+ )
+ parser.add_argument(
+ "--images",
+ nargs="+",
+ help="List of image file names to plot",
+ default=[],
+ type=str,
+ )
+ parser.add_argument(
+ "--dpi",
+ default=200,
+ type=int,
+ help="DPI for the output images",
+ )
+ parser.add_argument(
+ "--alpha",
+ default=0.5,
+ type=float,
+ help="Transparency of labels shade",
+ )
+
+ return parser.parse_args()
+
+
+def map_color(arr, input_color, output_color, rtol=1e-09):
+ """
+ Maps one color to another
+ """
+ input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
+ output = arr.copy()
+ output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
+ return output
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ if args.output_dir is None:
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
+ else:
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ # Store args
+ output_yml = output_dir / "labels.yml"
+ with open(output_yml, "w") as f:
+ yaml.dump(vars(args), f)
+
+ # Data dirs
+ imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
+ labels_path = Path(args.masker_test_set_dir) / "labels"
+
+ # Read CSV
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
+
+ # Set up plot
+ sns.reset_orig()
+ sns.set(style="whitegrid")
+ plt.rcParams.update({"font.family": "serif"})
+ plt.rcParams.update(
+ {
+ "font.serif": [
+ "Computer Modern Roman",
+ "Times New Roman",
+ "Utopia",
+ "New Century Schoolbook",
+ "Century Schoolbook L",
+ "ITC Bookman",
+ "Bookman",
+ "Times",
+ "Palatino",
+ "Charter",
+                "serif",
+                "Bitstream Vera Serif",
+ "DejaVu Serif",
+ ]
+ }
+ )
+
+ fig, axes = plt.subplots(
+ nrows=1, ncols=len(args.images), dpi=args.dpi, figsize=(len(args.images) * 5, 5)
+ )
+
+ for idx, img_filename in enumerate(args.images):
+
+ # Read images
+ img_path = imgs_orig_path / img_filename
+ label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
+ img, label = crop_and_resize(img_path, label_path)
+
+ # Map label colors
+ label_colmap = label.astype(float)
+ label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
+ label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
+
+ ax = axes[idx]
+ ax.imshow(img)
+ ax.imshow(label_colmap, alpha=args.alpha)
+ ax.axis("off")
+
+ # Legend
+ handles = []
+ lw = 1.0
+ handles.append(
+ mpatches.Patch(
+ facecolor=color_must, label="must", linewidth=lw, alpha=args.alpha
+ )
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_may, label="may", linewidth=lw, alpha=args.alpha)
+ )
+ handles.append(
+ mpatches.Patch(
+ facecolor=color_cannot, label="cannot", linewidth=lw, alpha=args.alpha
+ )
+ )
+ labels = ["Must-be-flooded", "May-be-flooded", "Cannot-be-flooded"]
+ fig.legend(
+ handles=handles,
+ labels=labels,
+ loc="upper center",
+ bbox_to_anchor=(0.0, 0.85, 1.0, 0.075),
+ ncol=len(args.images),
+ fontsize="medium",
+ frameon=False,
+ )
+
+ # Save figure
+ output_fig = output_dir / "labels.png"
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
diff --git a/figures/metrics.py b/figures/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b165eeeb3eb6bf975dd91211dbf6349590156ad
--- /dev/null
+++ b/figures/metrics.py
@@ -0,0 +1,676 @@
+"""
+This script plots examples of the images that get the best and worst metrics
+"""
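+# Example usage (paths are placeholders):
+#   python figures/metrics.py --input_csv ablations_metrics_20210311.csv \
+#       --masker_test_set_dir /path/to/masker-test-set \
+#       --models_log_path /path/to/models/logs --output_dir /path/to/plots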
+print("Imports...", end="")
+import os
+import sys
+from argparse import ArgumentParser
+from pathlib import Path
+
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import yaml
+from imageio import imread
+from skimage.color import rgba2rgb
+from sklearn.metrics.pairwise import euclidean_distances
+
+sys.path.append("../")
+
+from climategan.data import encode_mask_label
+from climategan.eval_metrics import edges_coherence_std_min
+from eval_masker import crop_and_resize
+
+# -----------------------
+# ----- Constants -----
+# -----------------------
+
+# Metrics
+metrics = ["error", "f05", "edge_coherence"]
+
+dict_metrics = {
+ "names": {
+ "tpr": "TPR, Recall, Sensitivity",
+ "tnr": "TNR, Specificity, Selectivity",
+ "fpr": "FPR",
+ "fpt": "False positives relative to image size",
+ "fnr": "FNR, Miss rate",
+ "fnt": "False negatives relative to image size",
+ "mpr": "May positive rate (MPR)",
+ "mnr": "May negative rate (MNR)",
+ "accuracy": "Accuracy (ignoring may)",
+ "error": "Error",
+ "f05": "F05 score",
+ "precision": "Precision",
+ "edge_coherence": "Edge coherence",
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
+ },
+ "key_metrics": ["error", "f05", "edge_coherence"],
+}
+
+
+# Colors
+colorblind_palette = sns.color_palette("colorblind")
+color_cannot = colorblind_palette[1]
+color_must = colorblind_palette[2]
+color_may = colorblind_palette[7]
+color_pred = colorblind_palette[4]
+
+icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
+color_tp = icefire[0]
+color_tn = icefire[1]
+color_fp = icefire[4]
+color_fn = icefire[3]
+
+
+def parsed_args():
+ """
+    Parses and returns command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--input_csv",
+ default="ablations_metrics_20210311.csv",
+ type=str,
+ help="CSV containing the results of the ablation study",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--models_log_path",
+ default=None,
+ type=str,
+ help="Path containing the log files of the models",
+ )
+ parser.add_argument(
+ "--masker_test_set_dir",
+ default=None,
+ type=str,
+ help="Directory containing the test images",
+ )
+ parser.add_argument(
+ "--best_model",
+ default="dada, msd_spade, pseudo",
+ type=str,
+ help="The string identifier of the best model",
+ )
+ parser.add_argument(
+ "--dpi",
+ default=200,
+ type=int,
+ help="DPI for the output images",
+ )
+ parser.add_argument(
+ "--alpha",
+ default=0.5,
+ type=float,
+ help="Transparency of labels shade",
+ )
+ parser.add_argument(
+ "--percentile",
+ default=0.05,
+ type=float,
+        help="Top/bottom fraction of the metric ranking to sample example images from",
+ )
+ parser.add_argument(
+ "--seed",
+ default=None,
+ type=int,
+ help="Bootstrap random seed, for reproducibility",
+ )
+ parser.add_argument(
+ "--no_images",
+ action="store_true",
+ default=False,
+ help="Do not generate images",
+ )
+
+ return parser.parse_args()
+
+
+def map_color(arr, input_color, output_color, rtol=1e-09):
+ """
+ Maps one color to another
+ """
+ input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
+ output = arr.copy()
+ output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
+ return output
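+# Illustrative usage (as done in the plotting functions below):
+# map_color(label, (255, 0, 0), color_cannot) recolors every exactly-red pixel of
+# the label map with the "cannot-be-flooded" palette color.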
+
+
+def plot_labels(ax, img, label, img_id, do_legend):
+ label_colmap = label.astype(float)
+ label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
+ label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
+
+ ax.imshow(img)
+ ax.imshow(label_colmap, alpha=0.5)
+ ax.axis("off")
+
+ # Annotation
+ ax.annotate(
+ xy=(0.05, 0.95),
+ xycoords="axes fraction",
+ xytext=(0.05, 0.95),
+ textcoords="axes fraction",
+ text=img_id,
+ fontsize="x-large",
+ verticalalignment="top",
+ color="white",
+ )
+
+ # Legend
+ if do_legend:
+ handles = []
+ lw = 1.0
+ handles.append(
+ mpatches.Patch(facecolor=color_must, label="must", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+            mpatches.Patch(facecolor=color_may, label="may", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(
+                facecolor=color_cannot, label="cannot", linewidth=lw, alpha=0.66
+ )
+ )
+ labels = ["Must-be-flooded", "May-be-flooded", "Cannot-be-flooded"]
+ ax.legend(
+ handles=handles,
+ labels=labels,
+ bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
+ ncol=3,
+ mode="expand",
+ fontsize="xx-small",
+ frameon=False,
+ )
+
+
+def plot_pred(ax, img, pred, img_id, do_legend):
+ pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
+
+ pred_colmap = pred.astype(float)
+ pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
+ pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
+ pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma
+
+ ax.imshow(img)
+ ax.imshow(pred_colmap_ma, alpha=0.5)
+ ax.axis("off")
+
+ # Annotation
+ ax.annotate(
+ xy=(0.05, 0.95),
+ xycoords="axes fraction",
+ xytext=(0.05, 0.95),
+ textcoords="axes fraction",
+ text=img_id,
+ fontsize="x-large",
+ verticalalignment="top",
+ color="white",
+ )
+
+ # Legend
+ if do_legend:
+ handles = []
+ lw = 1.0
+ handles.append(
+            mpatches.Patch(facecolor=color_pred, label="pred", linewidth=lw, alpha=0.66)
+ )
+ labels = ["Prediction"]
+ ax.legend(
+ handles=handles,
+ labels=labels,
+ bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
+ ncol=3,
+ mode="expand",
+ fontsize="xx-small",
+ frameon=False,
+ )
+
+
+def plot_correct_incorrect(ax, img_filename, img, label, img_id, do_legend):
+ # FP
+ fp_map = imread(
+ model_path / "eval-metrics/fp" / "{}_fp.png".format(Path(img_filename).stem)
+ )
+ fp_map = np.tile(np.expand_dims(fp_map, axis=2), reps=(1, 1, 3))
+
+ fp_map_colmap = fp_map.astype(float)
+ fp_map_colmap = map_color(fp_map_colmap, (1, 1, 1), color_fp)
+
+ # FN
+ fn_map = imread(
+ model_path / "eval-metrics/fn" / "{}_fn.png".format(Path(img_filename).stem)
+ )
+ fn_map = np.tile(np.expand_dims(fn_map, axis=2), reps=(1, 1, 3))
+
+ fn_map_colmap = fn_map.astype(float)
+ fn_map_colmap = map_color(fn_map_colmap, (1, 1, 1), color_fn)
+
+ # TP
+ tp_map = imread(
+ model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
+ )
+ tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
+
+ tp_map_colmap = tp_map.astype(float)
+ tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
+
+ # TN
+ tn_map = imread(
+ model_path / "eval-metrics/tn" / "{}_tn.png".format(Path(img_filename).stem)
+ )
+ tn_map = np.tile(np.expand_dims(tn_map, axis=2), reps=(1, 1, 3))
+
+ tn_map_colmap = tn_map.astype(float)
+ tn_map_colmap = map_color(tn_map_colmap, (1, 1, 1), color_tn)
+
+ label_colmap = label.astype(float)
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
+ label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
+ label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma
+
+ # Combine masks
+ maps = fp_map_colmap + fn_map_colmap + tp_map_colmap + tn_map_colmap
+ maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
+ maps_ma = maps_ma.mask * img + maps_ma
+
+ ax.imshow(img)
+ ax.imshow(label_colmap_ma, alpha=0.5)
+ ax.imshow(maps_ma, alpha=0.5)
+ ax.axis("off")
+
+ # Annotation
+ ax.annotate(
+ xy=(0.05, 0.95),
+ xycoords="axes fraction",
+ xytext=(0.05, 0.95),
+ textcoords="axes fraction",
+ text=img_id,
+ fontsize="x-large",
+ verticalalignment="top",
+ color="white",
+ )
+
+ # Legend
+ if do_legend:
+ handles = []
+ lw = 1.0
+ handles.append(
+ mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_tn, label="TN", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_fp, label="FP", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_fn, label="FN", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(
+ facecolor=color_may, label="May-be-flooded", linewidth=lw, alpha=0.66
+ )
+ )
+ labels = ["TP", "TN", "FP", "FN", "May-be-flooded"]
+ ax.legend(
+ handles=handles,
+ labels=labels,
+ bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
+ ncol=5,
+ mode="expand",
+ fontsize="xx-small",
+ frameon=False,
+ )
+
+
+def plot_edge_coherence(ax, img, label, pred, img_id, do_legend):
+ pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
+
+ ec, pred_ec, label_ec = edges_coherence_std_min(
+ np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
+ )
+
+ ##################
+ # Edge distances #
+ ##################
+
+ # Location of edges
+ pred_ec_coord = np.argwhere(pred_ec > 0)
+ label_ec_coord = np.argwhere(label_ec > 0)
+
+ # Normalized pairwise distances between pred and label
+ dist_mat = np.divide(
+ euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
+ )
+
+ # Standard deviation of the minimum distance from pred to label
+ min_dist = np.min(dist_mat, axis=1) # noqa: F841
+
+ #############
+ # Make plot #
+ #############
+
+ pred_ec = np.tile(
+ np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
+ )
+ pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
+ pred_ec_colmap_ma = np.ma.masked_not_equal(pred_ec_colmap, color_pred) # noqa: F841
+
+ label_ec = np.tile(
+ np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
+ )
+ label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)
+ label_ec_colmap_ma = np.ma.masked_not_equal( # noqa: F841
+ label_ec_colmap, color_must
+ )
+
+ # Combined pred and label edges
+ combined_ec = pred_ec_colmap + label_ec_colmap
+ combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
+ combined_ec_img = combined_ec_ma.mask * img + combined_ec
+
+ # Pred
+ pred_colmap = pred.astype(float)
+ pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
+ pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
+
+ # Must
+ label_colmap = label.astype(float)
+ label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
+ label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)
+
+ # TP
+ tp_map = imread(
+ model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(srs_sel.filename).stem)
+ )
+ tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
+ tp_map_colmap = tp_map.astype(float)
+ tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
+ tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)
+
+ # Combination
+ comb_pred = (
+ (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
+ & tp_map_colmap_ma.mask
+ & combined_ec_ma.mask
+ ) * pred_colmap
+ comb_label = (
+ (label_colmap_ma.mask ^ pred_colmap_ma.mask)
+ & pred_colmap_ma.mask
+ & combined_ec_ma.mask
+ ) * label_colmap
+ comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
+ combined = comb_tp + comb_label + comb_pred
+ combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
+ combined_ma = combined_ma.mask * combined_ec_img + combined_ma
+
+ ax.imshow(combined_ec_img, alpha=1)
+ ax.imshow(combined_ma, alpha=0.5)
+ ax.axis("off")
+
+ # Plot lines
+ idx_sort_x = np.argsort(pred_ec_coord[:, 1])
+ offset = 100
+ for idx in range(offset, pred_ec_coord.shape[0], offset):
+ y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
+ argmin = np.argmin(dist_mat[idx_sort_x[idx]])
+ y1, x1 = label_ec_coord[argmin, :]
+ ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)
+
+ # Annotation
+ ax.annotate(
+ xy=(0.05, 0.95),
+ xycoords="axes fraction",
+ xytext=(0.05, 0.95),
+ textcoords="axes fraction",
+ text=img_id,
+ fontsize="x-large",
+ verticalalignment="top",
+ color="white",
+ )
+ # Legend
+ if do_legend:
+ handles = []
+ lw = 1.0
+ handles.append(
+ mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_pred, label="pred", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(
+ facecolor=color_must, label="Must-be-flooded", linewidth=lw, alpha=0.66
+ )
+ )
+ labels = ["TP", "Prediction", "Must-be-flooded"]
+ ax.legend(
+ handles=handles,
+ labels=labels,
+ bbox_to_anchor=(0.0, 1.0, 1.0, 0.075),
+ ncol=3,
+ mode="expand",
+ fontsize="xx-small",
+ frameon=False,
+ )
+
+
+def plot_images_metric(axes, metric, img_filename, img_id, do_legend):
+
+ # Read images
+ img_path = imgs_orig_path / img_filename
+ label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
+ img, label = crop_and_resize(img_path, label_path)
+ img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
+ pred = imread(
+ model_path / "eval-metrics/pred" / "{}_pred.png".format(Path(img_filename).stem)
+ )
+
+ # Label
+ plot_labels(axes[0], img, label, img_id, do_legend)
+
+ # Prediction
+ plot_pred(axes[1], img, pred, img_id, do_legend)
+
+ # Correct / incorrect
+ if metric in ["error", "f05"]:
+ plot_correct_incorrect(axes[2], img_filename, img, label, img_id, do_legend)
+ # Edge coherence
+ elif metric == "edge_coherence":
+ plot_edge_coherence(axes[2], img, label, pred, img_id, do_legend)
+ else:
+        raise ValueError("Unknown metric: {}".format(metric))
+
+
+def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
+
+ sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)
+
+ # Set X-label
+ ax.set_xlabel(dict_metrics["names"][x_metric], rotation=0, fontsize="medium")
+
+ # Set Y-label
+ ax.set_ylabel(dict_metrics["names"][y_metric], rotation=90, fontsize="medium")
+
+ # Change spines
+ sns.despine(ax=ax, left=True, bottom=True)
+
+ annotate_scatterplot(ax, dict_images, x_metric, y_metric)
+
+
+def scatterplot_metrics(ax, df, dict_images):
+
+ sns.scatterplot(data=df, x="error", y="f05", hue="edge_coherence", ax=ax)
+
+ # Set X-label
+ ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")
+
+ # Set Y-label
+ ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")
+
+ annotate_scatterplot(ax, dict_images, "error", "f05")
+
+ # Change spines
+ sns.despine(ax=ax, left=True, bottom=True)
+
+ # Set XY limits
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ ax.set_xlim([0.0, xlim[1]])
+ ax.set_ylim([ylim[0], 1.0])
+
+
+def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ x_len = xlim[1] - xlim[0]
+ y_len = ylim[1] - ylim[0]
+ x_th = xlim[1] - x_len / 2.0
+ y_th = ylim[1] - y_len / 2.0
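+    # x_th / y_th are the axis midpoints: each annotation is offset toward the
+    # centre of the plot so labels near the axis limits stay inside the figure.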
+ for text, d in dict_images.items():
+ x = d[x_metric]
+ y = d[y_metric]
+ x_text = x + x_len * offset if x < x_th else x - x_len * offset
+ y_text = y + y_len * offset if y < y_th else y - y_len * offset
+ ax.annotate(
+ xy=(x, y),
+ xycoords="data",
+ xytext=(x_text, y_text),
+ textcoords="data",
+ text=text,
+ arrowprops=dict(facecolor="black", shrink=0.05),
+ fontsize="medium",
+ color="black",
+ )
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ if args.output_dir is None:
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
+ else:
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ # Store args
+ output_yml = output_dir / "labels.yml"
+ with open(output_yml, "w") as f:
+ yaml.dump(vars(args), f)
+
+ # Data dirs
+ imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
+ labels_path = Path(args.masker_test_set_dir) / "labels"
+
+ # Read CSV
+ df = pd.read_csv(args.input_csv, index_col="model_img_idx")
+
+ # Select best model
+ df = df.loc[df.model_feats == args.best_model]
+ v_key, model_dir = df.model.unique()[0].split("/")
+ model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir
+
+ # Set up plot
+ sns.reset_orig()
+ sns.set(style="whitegrid")
+ plt.rcParams.update({"font.family": "serif"})
+ plt.rcParams.update(
+ {
+ "font.serif": [
+ "Computer Modern Roman",
+ "Times New Roman",
+ "Utopia",
+ "New Century Schoolbook",
+ "Century Schoolbook L",
+ "ITC Bookman",
+ "Bookman",
+ "Times",
+ "Palatino",
+ "Charter",
+                "serif",
+                "Bitstream Vera Serif",
+ "DejaVu Serif",
+ ]
+ }
+ )
+
+    if args.seed is not None:
+ np.random.seed(args.seed)
+ img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ dict_images = {}
+ idx = 0
+ for metric in metrics:
+
+ fig, axes = plt.subplots(nrows=2, ncols=3, dpi=200, figsize=(18, 12))
+
+ # Select best
+ if metric == "error":
+ ascending = True
+ else:
+ ascending = False
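+        # Draw one example at random from the top `percentile` fraction of the
+        # ranking for this metric (sorted best-first via `ascending`)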
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
+ img_id = img_ids[idx]
+ dict_images.update({img_id: srs_sel})
+
+ # Read images
+ img_filename = srs_sel.filename
+
+ if not args.no_images:
+ axes_row = axes[0, :]
+ plot_images_metric(axes_row, metric, img_filename, img_id, do_legend=True)
+
+ idx += 1
+
+ # Select worst
+ if metric == "error":
+ ascending = False
+ else:
+ ascending = True
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
+ img_id = img_ids[idx]
+ dict_images.update({img_id: srs_sel})
+
+ # Read images
+ img_filename = srs_sel.filename
+
+ if not args.no_images:
+ axes_row = axes[1, :]
+ plot_images_metric(axes_row, metric, img_filename, img_id, do_legend=False)
+
+ idx += 1
+
+ # Save figure
+ output_fig = output_dir / "{}.png".format(metric)
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
+
+ fig = plt.figure(dpi=200)
+ scatterplot_metrics(fig.gca(), df, dict_images)
+
+ # fig, axes = plt.subplots(nrows=1, ncols=3, dpi=200, figsize=(18, 5))
+ #
+ # scatterplot_metrics_pair(axes[0], df, 'error', 'f05', dict_images)
+ # scatterplot_metrics_pair(axes[1], df, 'error', 'edge_coherence', dict_images)
+ # scatterplot_metrics_pair(axes[2], df, 'f05', 'edge_coherence', dict_images)
+ #
+ output_fig = output_dir / "scatterplots.png"
+ fig.savefig(output_fig, dpi=fig.dpi, bbox_inches="tight")
diff --git a/figures/metrics_onefig.py b/figures/metrics_onefig.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9d372dcbb1bed2fffbfd8e81d6da749ceab730b
--- /dev/null
+++ b/figures/metrics_onefig.py
@@ -0,0 +1,772 @@
+"""
+This script plots examples of the images that get the best and worst metrics
+"""
+print("Imports...", end="")
+import os
+import sys
+from argparse import ArgumentParser
+from pathlib import Path
+
+import matplotlib.patches as mpatches
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import yaml
+from imageio import imread
+from matplotlib.gridspec import GridSpec
+from skimage.color import rgba2rgb
+from sklearn.metrics.pairwise import euclidean_distances
+
+sys.path.append("../")
+
+from climategan.data import encode_mask_label
+from climategan.eval_metrics import edges_coherence_std_min
+from eval_masker import crop_and_resize
+
+# -----------------------
+# ----- Constants -----
+# -----------------------
+
+# Metrics
+metrics = ["error", "f05", "edge_coherence"]
+
+dict_metrics = {
+ "names": {
+ "tpr": "TPR, Recall, Sensitivity",
+ "tnr": "TNR, Specificity, Selectivity",
+ "fpr": "FPR",
+ "fpt": "False positives relative to image size",
+ "fnr": "FNR, Miss rate",
+ "fnt": "False negatives relative to image size",
+ "mpr": "May positive rate (MPR)",
+ "mnr": "May negative rate (MNR)",
+ "accuracy": "Accuracy (ignoring may)",
+ "error": "Error",
+ "f05": "F05 score",
+ "precision": "Precision",
+ "edge_coherence": "Edge coherence",
+ "accuracy_must_may": "Accuracy (ignoring cannot)",
+ },
+ "key_metrics": ["error", "f05", "edge_coherence"],
+}
+
+
+# Colors
+colorblind_palette = sns.color_palette("colorblind")
+color_cannot = colorblind_palette[1]
+color_must = colorblind_palette[2]
+color_may = colorblind_palette[7]
+color_pred = colorblind_palette[4]
+
+icefire = sns.color_palette("icefire", as_cmap=False, n_colors=5)
+color_tp = icefire[0]
+color_tn = icefire[1]
+color_fp = icefire[4]
+color_fn = icefire[3]
+
+
+def parsed_args():
+ """
+    Parse and return command-line args
+
+ Returns:
+ argparse.Namespace: the parsed arguments
+ """
+ parser = ArgumentParser()
+ parser.add_argument(
+ "--input_csv",
+ default="ablations_metrics_20210311.csv",
+ type=str,
+ help="CSV containing the results of the ablation study",
+ )
+ parser.add_argument(
+ "--output_dir",
+ default=None,
+ type=str,
+ help="Output directory",
+ )
+ parser.add_argument(
+ "--models_log_path",
+ default=None,
+ type=str,
+ help="Path containing the log files of the models",
+ )
+ parser.add_argument(
+ "--masker_test_set_dir",
+ default=None,
+ type=str,
+ help="Directory containing the test images",
+ )
+ parser.add_argument(
+ "--best_model",
+ default="dada, msd_spade, pseudo",
+ type=str,
+ help="The string identifier of the best model",
+ )
+ parser.add_argument(
+ "--dpi",
+ default=200,
+ type=int,
+ help="DPI for the output images",
+ )
+ parser.add_argument(
+ "--alpha",
+ default=0.5,
+ type=float,
+ help="Transparency of labels shade",
+ )
+ parser.add_argument(
+ "--percentile",
+ default=0.05,
+ type=float,
+        help="Top/bottom fraction of the metric ranking to sample example images from",
+ )
+ parser.add_argument(
+ "--seed",
+ default=None,
+ type=int,
+ help="Bootstrap random seed, for reproducibility",
+ )
+ parser.add_argument(
+ "--no_images",
+ action="store_true",
+ default=False,
+ help="Do not generate images",
+ )
+
+ return parser.parse_args()
+
+
+def map_color(arr, input_color, output_color, rtol=1e-09):
+ """
+ Maps one color to another
+ """
+ input_color_arr = np.tile(input_color, (arr.shape[:2] + (1,)))
+ output = arr.copy()
+ output[np.all(np.isclose(arr, input_color_arr, rtol=rtol), axis=2)] = output_color
+ return output
+
+
+def plot_labels(ax, img, label, img_id, n_, add_title, do_legend):
+ label_colmap = label.astype(float)
+ label_colmap = map_color(label_colmap, (255, 0, 0), color_cannot)
+ label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
+
+ ax.imshow(img)
+ ax.imshow(label_colmap, alpha=0.5)
+ ax.axis("off")
+
+ if n_ in [1, 3, 5]:
+ color_ = "green"
+ else:
+ color_ = "red"
+
+ ax.text(
+ -0.15,
+ 0.5,
+ img_id,
+ color=color_,
+ fontweight="roman",
+ fontsize="x-large",
+ horizontalalignment="left",
+ verticalalignment="center",
+ transform=ax.transAxes,
+ )
+
+ if add_title:
+ ax.set_title("Labels", rotation=0, fontsize="x-large")
+
+
+def plot_pred(ax, img, pred, img_id, add_title, do_legend):
+ pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
+
+ pred_colmap = pred.astype(float)
+ pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
+ pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
+ pred_colmap_ma = pred_colmap_ma.mask * img + pred_colmap_ma
+
+ ax.imshow(img)
+ ax.imshow(pred_colmap_ma, alpha=0.5)
+ ax.axis("off")
+
+ if add_title:
+ ax.set_title("Prediction", rotation=0, fontsize="x-large")
+
+
+def plot_correct_incorrect(
+ ax, img_filename, img, metric, label, img_id, n_, add_title, do_legend
+):
+ # FP
+ fp_map = imread(
+ model_path / "eval-metrics/fp" / "{}_fp.png".format(Path(img_filename).stem)
+ )
+ fp_map = np.tile(np.expand_dims(fp_map, axis=2), reps=(1, 1, 3))
+
+ fp_map_colmap = fp_map.astype(float)
+ fp_map_colmap = map_color(fp_map_colmap, (1, 1, 1), color_fp)
+
+ # FN
+ fn_map = imread(
+ model_path / "eval-metrics/fn" / "{}_fn.png".format(Path(img_filename).stem)
+ )
+ fn_map = np.tile(np.expand_dims(fn_map, axis=2), reps=(1, 1, 3))
+
+ fn_map_colmap = fn_map.astype(float)
+ fn_map_colmap = map_color(fn_map_colmap, (1, 1, 1), color_fn)
+
+ # TP
+ tp_map = imread(
+ model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(img_filename).stem)
+ )
+ tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
+
+ tp_map_colmap = tp_map.astype(float)
+ tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
+
+ # TN
+ tn_map = imread(
+ model_path / "eval-metrics/tn" / "{}_tn.png".format(Path(img_filename).stem)
+ )
+ tn_map = np.tile(np.expand_dims(tn_map, axis=2), reps=(1, 1, 3))
+
+ tn_map_colmap = tn_map.astype(float)
+ tn_map_colmap = map_color(tn_map_colmap, (1, 1, 1), color_tn)
+
+ label_colmap = label.astype(float)
+ label_colmap = map_color(label_colmap, (0, 0, 0), color_may)
+ label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_may)
+ label_colmap_ma = label_colmap_ma.mask * img + label_colmap_ma
+
+ # Combine masks
+ maps = fp_map_colmap + fn_map_colmap + tp_map_colmap + tn_map_colmap
+ maps_ma = np.ma.masked_equal(maps, (0, 0, 0))
+ maps_ma = maps_ma.mask * img + maps_ma
+
+ ax.imshow(img)
+ ax.imshow(label_colmap_ma, alpha=0.5)
+ ax.imshow(maps_ma, alpha=0.5)
+ ax.axis("off")
+
+ if add_title:
+ ax.set_title("Metric", rotation=0, fontsize="x-large")
+
+
+def plot_edge_coherence(ax, img, metric, label, pred, img_id, n_, add_title, do_legend):
+ pred = np.tile(np.expand_dims(pred, axis=2), reps=(1, 1, 3))
+
+ ec, pred_ec, label_ec = edges_coherence_std_min(
+ np.squeeze(pred[:, :, 0]), np.squeeze(encode_mask_label(label, "flood"))
+ )
+
+ ##################
+ # Edge distances #
+ ##################
+
+ # Location of edges
+ pred_ec_coord = np.argwhere(pred_ec > 0)
+ label_ec_coord = np.argwhere(label_ec > 0)
+
+ # Normalized pairwise distances between pred and label
+ dist_mat = np.divide(
+ euclidean_distances(pred_ec_coord, label_ec_coord), pred_ec.shape[0]
+ )
+
+ # Standard deviation of the minimum distance from pred to label
+ min_dist = np.min(dist_mat, axis=1) # noqa: F841
+
+ #############
+ # Make plot #
+ #############
+
+ pred_ec = np.tile(
+ np.expand_dims(np.asarray(pred_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
+ )
+ pred_ec_colmap = map_color(pred_ec, (1, 1, 1), color_pred)
+ pred_ec_colmap_ma = np.ma.masked_not_equal(pred_ec_colmap, color_pred) # noqa: F841
+
+ label_ec = np.tile(
+ np.expand_dims(np.asarray(label_ec > 0, dtype=float), axis=2), reps=(1, 1, 3)
+ )
+ label_ec_colmap = map_color(label_ec, (1, 1, 1), color_must)
+ label_ec_colmap_ma = np.ma.masked_not_equal( # noqa: F841
+ label_ec_colmap, color_must
+ )
+
+ # Combined pred and label edges
+ combined_ec = pred_ec_colmap + label_ec_colmap
+ combined_ec_ma = np.ma.masked_equal(combined_ec, (0, 0, 0))
+ combined_ec_img = combined_ec_ma.mask * img + combined_ec
+
+ # Pred
+ pred_colmap = pred.astype(float)
+ pred_colmap = map_color(pred_colmap, (1, 1, 1), color_pred)
+ pred_colmap_ma = np.ma.masked_not_equal(pred_colmap, color_pred)
+
+ # Must
+ label_colmap = label.astype(float)
+ label_colmap = map_color(label_colmap, (0, 0, 255), color_must)
+ label_colmap_ma = np.ma.masked_not_equal(label_colmap, color_must)
+
+ # TP
+ tp_map = imread(
+ model_path / "eval-metrics/tp" / "{}_tp.png".format(Path(srs_sel.filename).stem)
+ )
+ tp_map = np.tile(np.expand_dims(tp_map, axis=2), reps=(1, 1, 3))
+ tp_map_colmap = tp_map.astype(float)
+ tp_map_colmap = map_color(tp_map_colmap, (1, 1, 1), color_tp)
+ tp_map_colmap_ma = np.ma.masked_not_equal(tp_map_colmap, color_tp)
+
+ # Combination
+ comb_pred = (
+ (pred_colmap_ma.mask ^ tp_map_colmap_ma.mask)
+ & tp_map_colmap_ma.mask
+ & combined_ec_ma.mask
+ ) * pred_colmap
+ comb_label = (
+ (label_colmap_ma.mask ^ pred_colmap_ma.mask)
+ & pred_colmap_ma.mask
+ & combined_ec_ma.mask
+ ) * label_colmap
+ comb_tp = combined_ec_ma.mask * tp_map_colmap.copy()
+ combined = comb_tp + comb_label + comb_pred
+ combined_ma = np.ma.masked_equal(combined, (0, 0, 0))
+ combined_ma = combined_ma.mask * combined_ec_img + combined_ma
+
+ ax.imshow(combined_ec_img, alpha=1)
+ ax.imshow(combined_ma, alpha=0.5)
+ ax.axis("off")
+
+ # Plot lines
+ idx_sort_x = np.argsort(pred_ec_coord[:, 1])
+ offset = 100
+ for idx in range(offset, pred_ec_coord.shape[0], offset):
+ y0, x0 = pred_ec_coord[idx_sort_x[idx], :]
+ argmin = np.argmin(dist_mat[idx_sort_x[idx]])
+ y1, x1 = label_ec_coord[argmin, :]
+ ax.plot([x0, x1], [y0, y1], color="white", linewidth=0.5)
+
+ if add_title:
+ ax.set_title("Metric", rotation=0, fontsize="x-large")
+
+
+def plot_images_metric(
+ axes, metric, img_filename, img_id, n_, srs_sel, add_title, do_legend
+):
+
+ # Read images
+ img_path = imgs_orig_path / img_filename
+ label_path = labels_path / "{}_labeled.png".format(Path(img_filename).stem)
+ img, label = crop_and_resize(img_path, label_path)
+ img = rgba2rgb(img) if img.shape[-1] == 4 else img / 255.0
+
+ pred = imread(
+ model_path / "eval-metrics/pred" / "{}_pred.png".format(Path(img_filename).stem)
+ )
+
+ # Label
+ plot_labels(axes[0], img, label, img_id, n_, add_title, do_legend)
+
+ # Prediction
+ plot_pred(axes[1], img, pred, img_id, add_title, do_legend)
+
+ # Correct / incorrect
+ if metric in ["error", "f05"]:
+ plot_correct_incorrect(
+ axes[2],
+ img_filename,
+ img,
+ metric,
+ label,
+ img_id,
+ n_,
+ add_title,
+ do_legend=False,
+ )
+ handles = []
+ lw = 1.0
+ handles.append(
+ mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_tn, label="TN", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_fp, label="FP", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_fn, label="FN", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(
+ facecolor=color_may,
+ label="May-be-flooded",
+ linewidth=lw,
+ alpha=0.66,
+ )
+ )
+ labels = ["TP", "TN", "FP", "FN", "May-be-flooded"]
+ if metric == "error":
+ if n_ in [1, 3, 5]:
+ title = "Low error rate"
+ else:
+ title = "High error rate"
+ else:
+ if n_ in [1, 3, 5]:
+ title = "High F05 score"
+ else:
+ title = "Low F05 score"
+ # Edge coherence
+ elif metric == "edge_coherence":
+ plot_edge_coherence(
+ axes[2], img, metric, label, pred, img_id, n_, add_title, do_legend=False
+ )
+ handles = []
+ lw = 1.0
+ handles.append(
+ mpatches.Patch(facecolor=color_tp, label="TP", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(facecolor=color_pred, label="pred", linewidth=lw, alpha=0.66)
+ )
+ handles.append(
+ mpatches.Patch(
+ facecolor=color_must,
+ label="Must-be-flooded",
+ linewidth=lw,
+ alpha=0.66,
+ )
+ )
+ labels = ["TP", "Prediction", "Must-be-flooded"]
+ if n_ in [1, 3, 5]:
+ title = "High edge coherence"
+ else:
+ title = "Low edge coherence"
+
+ else:
+        raise ValueError("Unknown metric: {}".format(metric))
+
+    labels_values_title = "Error: {:.4f} \nF05: {:.4f} \nEdge coherence: {:.4f}".format(
+ srs_sel.error, srs_sel.f05, srs_sel.edge_coherence
+ )
+
+ plot_legend(axes[3], img, handles, labels, labels_values_title, title)
+
+
+def plot_legend(ax, img, handles, labels, labels_values_title, title):
+ img_ = np.zeros_like(img, dtype=np.uint8)
+ img_.fill(255)
+ ax.imshow(img_)
+ ax.axis("off")
+
+ leg1 = ax.legend(
+ handles=handles,
+ labels=labels,
+ title=title,
+ title_fontsize="medium",
+ labelspacing=0.6,
+ loc="upper left",
+ fontsize="x-small",
+ frameon=False,
+ )
+ leg1._legend_box.align = "left"
+
+ leg2 = ax.legend(
+ title=labels_values_title,
+ title_fontsize="small",
+ loc="lower left",
+ frameon=False,
+ )
+ leg2._legend_box.align = "left"
+
+ ax.add_artist(leg1)
+
+
+def scatterplot_metrics_pair(ax, df, x_metric, y_metric, dict_images):
+
+ sns.scatterplot(data=df, x=x_metric, y=y_metric, ax=ax)
+
+ # Set X-label
+ ax.set_xlabel(dict_metrics["names"][x_metric], rotation=0, fontsize="medium")
+
+ # Set Y-label
+ ax.set_ylabel(dict_metrics["names"][y_metric], rotation=90, fontsize="medium")
+
+ # Change spines
+ sns.despine(ax=ax, left=True, bottom=True)
+
+ annotate_scatterplot(ax, dict_images, x_metric, y_metric)
+
+
+def scatterplot_metrics(ax, df, df_all, dict_images, plot_all=False):
+
+ # Other
+ if plot_all:
+ sns.scatterplot(
+ data=df_all.loc[df_all.ground == True],
+ x="error", y="f05", hue="edge_coherence", ax=ax,
+ marker='+', alpha=0.25)
+ sns.scatterplot(
+ data=df_all.loc[df_all.instagan == True],
+ x="error", y="f05", hue="edge_coherence", ax=ax,
+ marker='x', alpha=0.25)
+ sns.scatterplot(
+            data=df_all.loc[(df_all.ground == False) & (df_all.instagan == False) &
+ (df_all.model_feats != args.best_model)],
+ x="error", y="f05", hue="edge_coherence", ax=ax,
+ marker='s', alpha=0.25)
+
+ # Best model
+ cmap_ = sns.cubehelix_palette(as_cmap=True)
+ sns.scatterplot(
+ data=df, x="error", y="f05", hue="edge_coherence", ax=ax, palette=cmap_
+ )
+
+ norm = plt.Normalize(df["edge_coherence"].min(), df["edge_coherence"].max())
+ sm = plt.cm.ScalarMappable(cmap=cmap_, norm=norm)
+ sm.set_array([])
+
+ # Remove the legend and add a colorbar
+ ax.get_legend().remove()
+ ax_cbar = ax.figure.colorbar(sm)
+ ax_cbar.set_label("Edge coherence", labelpad=8)
+
+ # Set X-label
+ ax.set_xlabel(dict_metrics["names"]["error"], rotation=0, fontsize="medium")
+
+ # Set Y-label
+ ax.set_ylabel(dict_metrics["names"]["f05"], rotation=90, fontsize="medium")
+
+ annotate_scatterplot(ax, dict_images, "error", "f05")
+
+ # Change spines
+ sns.despine(ax=ax, left=True, bottom=True)
+
+ # Set XY limits
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ ax.set_xlim([0.0, xlim[1]])
+ ax.set_ylim([ylim[0], 1.0])
+
+
+def annotate_scatterplot(ax, dict_images, x_metric, y_metric, offset=0.1):
+ xlim = ax.get_xlim()
+ ylim = ax.get_ylim()
+ x_len = xlim[1] - xlim[0]
+ y_len = ylim[1] - ylim[0]
+ x_th = xlim[1] - x_len / 2.0
+ y_th = ylim[1] - y_len / 2.0
+ for text, d in dict_images.items():
+ if text in ["B", "D", "F"]:
+ x = d[x_metric]
+ y = d[y_metric]
+
+ x_text = x + x_len * offset if x < x_th else x - x_len * offset
+ y_text = y + y_len * offset if y < y_th else y - y_len * offset
+
+ ax.annotate(
+ xy=(x, y),
+ xycoords="data",
+ xytext=(x_text, y_text),
+ textcoords="data",
+ text=text,
+ arrowprops=dict(facecolor="black", shrink=0.05),
+ fontsize="medium",
+ color="black",
+ )
+ elif text == "A":
+ x = (
+ dict_images["A"][x_metric]
+ + dict_images["C"][x_metric]
+ + dict_images["E"][x_metric]
+ ) / 3
+ y = (
+ dict_images["A"][y_metric]
+ + dict_images["C"][y_metric]
+ + dict_images["E"][y_metric]
+ ) / 3
+
+ x_text = x + x_len * 2 * offset if x < x_th else x - x_len * 2 * offset
+ y_text = (
+ y + y_len * 0.45 * offset if y < y_th else y - y_len * 0.45 * offset
+ )
+
+ ax.annotate(
+ xy=(x, y),
+ xycoords="data",
+ xytext=(x_text, y_text),
+ textcoords="data",
+ text="A, C, E",
+ arrowprops=dict(facecolor="black", shrink=0.05),
+ fontsize="medium",
+ color="black",
+ )
+
+
+if __name__ == "__main__":
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+ args = parsed_args()
+ print("Args:\n" + "\n".join([f" {k:20}: {v}" for k, v in vars(args).items()]))
+
+ # Determine output dir
+ if args.output_dir is None:
+ output_dir = Path(os.environ["SLURM_TMPDIR"])
+ else:
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ output_dir.mkdir(parents=True, exist_ok=False)
+
+ # Store args
+ output_yml = output_dir / "labels.yml"
+ with open(output_yml, "w") as f:
+ yaml.dump(vars(args), f)
+
+ # Data dirs
+ imgs_orig_path = Path(args.masker_test_set_dir) / "imgs"
+ labels_path = Path(args.masker_test_set_dir) / "labels"
+
+ # Read CSV
+ df_all = pd.read_csv(args.input_csv, index_col="model_img_idx")
+
+ # Select best model
+ df = df_all.loc[df_all.model_feats == args.best_model]
+ v_key, model_dir = df.model.unique()[0].split("/")
+ model_path = Path(args.models_log_path) / "ablation-{}".format(v_key) / model_dir
+
+ # Set up plot
+ sns.reset_orig()
+ sns.set(style="whitegrid")
+ plt.rcParams.update({"font.family": "serif"})
+ plt.rcParams.update(
+ {
+ "font.serif": [
+ "Computer Modern Roman",
+ "Times New Roman",
+ "Utopia",
+ "New Century Schoolbook",
+ "Century Schoolbook L",
+ "ITC Bookman",
+ "Bookman",
+ "Times",
+ "Palatino",
+ "Charter",
+                "serif",
+                "Bitstream Vera Serif",
+ "DejaVu Serif",
+ ]
+ }
+ )
+
+    if args.seed is not None:
+ np.random.seed(args.seed)
+ img_ids = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
+ dict_images = {}
+ idx = 0
+
+ # Define grid of subplots
+ grid_vmargin = 0.03 # Extent of the vertical margin between metric grids
+ ax_hspace = 0.04 # Extent of the vertical space between axes of same grid
+ ax_wspace = 0.05 # Extent of the horizontal space between axes of same grid
+ n_grids = len(metrics)
+ n_cols = 4
+ n_rows = 2
+ h_grid = (1.0 / n_grids) - ((n_grids - 1) * grid_vmargin) / n_grids
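+    # Height of each metric grid: the full figure height (1.0) minus the
+    # (n_grids - 1) inter-grid margins, split evenly across the n_grids grids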
+
+ fig1 = plt.figure(dpi=200, figsize=(11, 13))
+
+ n_ = 0
+ add_title = False
+ for metric_id, metric in enumerate(metrics):
+
+ # Create grid
+ top_grid = 1.0 - metric_id * h_grid - metric_id * grid_vmargin
+ bottom_grid = top_grid - h_grid
+ gridspec = GridSpec(
+ n_rows,
+ n_cols,
+ wspace=ax_wspace,
+ hspace=ax_hspace,
+ bottom=bottom_grid,
+ top=top_grid,
+ )
+
+ # Select best
+ if metric == "error":
+ ascending = True
+ else:
+ ascending = False
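+        # Draw one example at random from the top `percentile` fraction of the
+        # ranking for this metric (sorted best-first via `ascending`)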
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
+ img_id = img_ids[idx]
+ dict_images.update({img_id: srs_sel})
+ # Read images
+ img_filename = srs_sel.filename
+
+ axes_row = [fig1.add_subplot(gridspec[0, c]) for c in range(n_cols)]
+ if not args.no_images:
+ n_ += 1
+ if metric_id == 0:
+ add_title = True
+ plot_images_metric(
+ axes_row,
+ metric,
+ img_filename,
+ img_id,
+ n_,
+ srs_sel,
+ add_title=add_title,
+ do_legend=False,
+ )
+ add_title = False
+
+ idx += 1
+ print("1 more row done.")
+ # Select worst
+ if metric == "error":
+ ascending = False
+ else:
+ ascending = True
+ idx_rand = np.random.permutation(int(args.percentile * len(df)))[0]
+ srs_sel = df.sort_values(by=metric, ascending=ascending).iloc[idx_rand]
+ img_id = img_ids[idx]
+ dict_images.update({img_id: srs_sel})
+ # Read images
+ img_filename = srs_sel.filename
+
+ axes_row = [fig1.add_subplot(gridspec[1, c]) for c in range(n_cols)]
+ if not args.no_images:
+ n_ += 1
+ plot_images_metric(
+ axes_row,
+ metric,
+ img_filename,
+ img_id,
+ n_,
+ srs_sel,
+ add_title=add_title,
+ do_legend=False,
+ )
+
+ idx += 1
+ print("1 more row done.")
+
+ output_fig = output_dir / "all_metrics.png"
+
+ fig1.tight_layout() # (pad=1.5) #
+ fig1.savefig(output_fig, dpi=fig1.dpi, bbox_inches="tight")
+
+ # Scatter plot
+ fig2 = plt.figure(dpi=200)
+
+ scatterplot_metrics(fig2.gca(), df, df_all, dict_images)
+
+ # fig2, axes = plt.subplots(nrows=1, ncols=3, dpi=200, figsize=(18, 5))
+ #
+ # scatterplot_metrics_pair(axes[0], df, "error", "f05", dict_images)
+ # scatterplot_metrics_pair(axes[1], df, "error", "edge_coherence", dict_images)
+ # scatterplot_metrics_pair(axes[2], df, "f05", "edge_coherence", dict_images)
+
+ output_fig = output_dir / "scatterplots.png"
+ fig2.savefig(output_fig, dpi=fig2.dpi, bbox_inches="tight")
diff --git a/inferences.py b/inferences.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5efe3a3283a4798b98fe74f6d343d6cc28eaae2
--- /dev/null
+++ b/inferences.py
@@ -0,0 +1,108 @@
+# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
+# thank you @NimaBoscarino
+
+import torch
+from skimage.color import rgba2rgb
+from skimage.transform import resize
+import numpy as np
+
+from climategan.trainer import Trainer
+
+
+def uint8(array):
+ """
+    Converts an array to np.uint8 (only changes the dtype, no rescaling)
+ Args:
+ array (np.array): array to modify
+ Returns:
+ np.array(np.uint8): converted array
+ """
+ return array.astype(np.uint8)
+
+
+def resize_and_crop(img, to=640):
+ """
+    Resizes an image so that it keeps the aspect ratio and its smallest dimension
+    is `to`, then crops the resized image in its center so that the output is `to x to`
+ without aspect ratio distortion
+ Args:
+ img (np.array): np.uint8 255 image
+ Returns:
+ np.array: [0, 1] np.float32 image
+ """
+    # resize keeping aspect ratio: smallest dim is `to` (640 by default)
+ h, w = img.shape[:2]
+ if h < w:
+ size = (to, int(to * w / h))
+ else:
+ size = (int(to * h / w), to)
+
+ r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
+ r_img = uint8(r_img)
+
+ # crop in the center
+ H, W = r_img.shape[:2]
+
+ top = (H - to) // 2
+ left = (W - to) // 2
+
+ rc_img = r_img[top : top + to, left : left + to, :]
+
+ return rc_img / 255.0
+
+
+def to_m1_p1(img):
+ """
+ rescales a [0, 1] image to [-1, +1]
+ Args:
+ img (np.array): float32 numpy array of an image in [0, 1]
+ Raises:
+ ValueError: If the image is not in [0, 1]
+ Returns:
+ np.array(np.float32): array in [-1, +1]
+ """
+ if img.min() >= 0 and img.max() <= 1:
+ return (img.astype(np.float32) - 0.5) * 2
+ raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
+
+
+# No need to do any timing in this, since it's just for the HF Space
+class ClimateGAN:
+ def __init__(self, model_path) -> None:
+ torch.set_grad_enabled(False)
+ self.target_size = 640
+ self.trainer = Trainer.resume_from_path(
+ model_path,
+ setup=True,
+ inference=True,
+ new_exp=None,
+ )
+
+ # Does all three inferences at the moment.
+ def inference(self, orig_image):
+ image = self._preprocess_image(orig_image)
+
+ # Retrieve numpy events as a dict {event: array[BxHxWxC]}
+ outputs = self.trainer.infer_all(
+ image,
+ numpy=True,
+ bin_value=0.5,
+ )
+
+ return (
+ outputs["flood"].squeeze(),
+ outputs["wildfire"].squeeze(),
+ outputs["smog"].squeeze(),
+ )
+
+ def _preprocess_image(self, img):
+ # rgba to rgb
+ data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
+
+ # to args.target_size
+ data = resize_and_crop(data, self.target_size)
+
+ # resize() produces [0, 1] images, rescale to [-1, 1]
+ data = to_m1_p1(data)
+ return data
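+
+
+# Illustrative usage sketch (hypothetical paths, not an additional API):
+#     model = ClimateGAN("path/to/model/folder")
+#     flood, wildfire, smog = model.inference(skimage.io.imread("street.png"))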
diff --git a/requirements-3.8.2.txt b/requirements-3.8.2.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f1d3eb3171924566ca4dffe37c080ae5d310b53f
--- /dev/null
+++ b/requirements-3.8.2.txt
@@ -0,0 +1,91 @@
+addict==2.4.0
+APScheduler==3.7.0
+attrs==21.2.0
+backcall==0.2.0
+Brotli==1.0.9
+certifi==2021.5.30
+charset-normalizer==2.0.4
+click==8.0.1
+codecarbon==1.2.0
+comet-ml==3.15.3
+configobj==5.0.6
+cycler==0.10.0
+dash==2.0.0
+dash-bootstrap-components==0.13.0
+dash-core-components==2.0.0
+dash-html-components==2.0.0
+dash-table==5.0.0
+dataclasses==0.6
+decorator==5.0.9
+dulwich==0.20.25
+everett==2.0.1
+filelock==3.0.12
+fire==0.4.0
+Flask==2.0.1
+Flask-Compress==1.10.1
+future==0.18.2
+gdown==3.13.0
+hydra-core==0.11.3
+idna==3.2
+imageio==2.9.0
+ipython==7.27.0
+itsdangerous==2.0.1
+jedi==0.18.0
+Jinja2==3.0.1
+joblib==1.0.1
+jsonschema==3.2.0
+kiwisolver==1.3.2
+kornia==0.5.10
+MarkupSafe==2.0.1
+matplotlib==3.4.3
+matplotlib-inline==0.1.2
+networkx==2.6.2
+numpy==1.21.2
+nvidia-ml-py3==7.352.0
+omegaconf==1.4.1
+opencv-python==4.5.3.56
+packaging==21.0
+pandas==1.3.2
+parso==0.8.2
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==8.3.2
+plotly==5.3.1
+prompt-toolkit==3.0.20
+ptyprocess==0.7.0
+py-cpuinfo==8.0.0
+Pygments==2.10.0
+pynvml==11.0.0
+pyparsing==2.4.7
+pyrsistent==0.18.0
+PySocks==1.7.1
+python-dateutil==2.8.2
+pytorch-ranger==0.1.1
+pytz==2021.1
+PyWavelets==1.1.1
+PyYAML==5.4.1
+requests==2.26.0
+requests-toolbelt==0.9.1
+scikit-image==0.18.3
+scikit-learn==0.24.2
+scipy==1.7.1
+seaborn==0.11.2
+semantic-version==2.8.5
+six==1.16.0
+tenacity==8.0.1
+termcolor==1.1.0
+threadpoolctl==2.2.0
+tifffile==2021.8.30
+torch==1.7.1
+torch-optimizer==0.1.0
+torchvision==0.8.2
+tqdm==4.62.2
+traitlets==5.1.0
+typing-extensions==3.10.0.2
+tzlocal==2.1
+urllib3==1.26.6
+wcwidth==0.2.5
+websocket-client==1.2.1
+Werkzeug==2.0.1
+wrapt==1.12.1
+wurlitzer==3.0.2
\ No newline at end of file
diff --git a/requirements-any.txt b/requirements-any.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ac17ba237b7785696bebafe0e98c5394b33adda5
--- /dev/null
+++ b/requirements-any.txt
@@ -0,0 +1,20 @@
+addict
+codecarbon
+comet_ml
+hydra-core==0.11.3
+kornia
+omegaconf==1.4.1
+matplotlib
+numpy
+opencv-python
+packaging
+pandas
+PyYAML
+scikit-image
+scikit-learn
+scipy
+seaborn
+torch==1.7.0
+torch-optimizer
+torchvision==0.8.1
+tqdm
diff --git a/sbatch.py b/sbatch.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fb5cab4bfa57449479d252646b21c8d464e815f
--- /dev/null
+++ b/sbatch.py
@@ -0,0 +1,933 @@
+import datetime
+import itertools
+import os
+import re
+import subprocess
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import yaml
+
+
+def flatten_conf(conf, to={}, parents=[]):
+ """
+ Flattens a configuration dict: nested dictionaries are flattened
+ as key1.key2.key3 = value
+
+ conf.yaml:
+ ```yaml
+ a: 1
+ b:
+ c: 2
+ d:
+ e: 3
+ g:
+ sample: sequential
+ from: [4, 5]
+ ```
+
+ Is flattened to
+
+ {
+ "a": 1,
+ "b.c": 2,
+ "b.d.e": 3,
+ "b.g": {
+ "sample": "sequential",
+ "from": [4, 5]
+ }
+ }
+
+ Does not affect sampling dicts.
+
+ Args:
+ conf (dict): the configuration to flatten
+        to (dict, optional): the target flattened dict. Defaults to {}.
+ parents (list, optional): a final value's list of parents. Defaults to [].
+ """
+ for k, v in conf.items():
+ if isinstance(v, dict) and "sample" not in v:
+ flatten_conf(v, to, parents + [k])
+ else:
+ new_k = ".".join([str(p) for p in parents + [k]])
+ to[new_k] = v
+
+
+def env_to_path(path):
+    """Transforms an environment variable mentioned in a path
+    into its actual value. E.g. $HOME/clouds -> /home/vsch/clouds
+
+ Args:
+ path (str): path potentially containing the env variable
+
+ """
+ path_elements = path.split("/")
+ new_path = []
+ for el in path_elements:
+ if "$" in el:
+ new_path.append(os.environ[el.replace("$", "")])
+ else:
+ new_path.append(el)
+ return "/".join(new_path)
+
+
+class C:
+ HEADER = "\033[95m"
+ OKBLUE = "\033[94m"
+ OKGREEN = "\033[92m"
+ WARNING = "\033[93m"
+ FAIL = "\033[91m"
+ ENDC = "\033[0m"
+ BOLD = "\033[1m"
+ UNDERLINE = "\033[4m"
+ ITALIC = "\33[3m"
+ BEIGE = "\33[36m"
+
+
+def escape_path(path):
+ p = str(path)
+ return p.replace(" ", "\ ").replace("(", "\(").replace(")", "\)") # noqa: W605
+
+
+def warn(*args, **kwargs):
+ print("{}{}{}".format(C.WARNING, " ".join(args), C.ENDC), **kwargs)
+
+
+def parse_jobID(command_output):
+ """
+ get job id from successful sbatch command output like
+ `Submitted batch job 599583`
+
+ Args:
+ command_output (str): sbatch command's output
+
+ Returns:
+ int: the slurm job's ID
+ """
+ command_output = command_output.strip()
+ if isinstance(command_output, str):
+ if "Submitted batch job" in command_output:
+ return int(command_output.split()[-1])
+
+ return -1
+
+
+def now():
+ return str(datetime.datetime.now()).replace(" ", "_")
+
+
+def cols():
+ try:
+ col = os.get_terminal_size().columns
+ except Exception:
+ col = 50
+ return col
+
+
+def print_box(txt):
+ if not txt:
+ txt = "{}{}ERROR ⇪{}".format(C.BOLD, C.FAIL, C.ENDC)
+ lt = 7
+ else:
+ lt = len(txt)
+ nlt = lt + 12
+ txt = "|" + " " * 5 + txt + " " * 5 + "|"
+ line = "-" * nlt
+ empty = "|" + " " * (nlt - 2) + "|"
+ print(line)
+ print(empty)
+ print(txt)
+ print(empty)
+ print(line)
+
+
+def print_header(idx):
+ b = C.BOLD
+ bl = C.OKBLUE
+ e = C.ENDC
+ char = "≡"
+ c = cols()
+
+ txt = " " * 20
+ txt += f"{b}{bl}Run {idx}{e}"
+ txt += " " * 20
+ ln = len(txt) - len(b) - len(bl) - len(e)
+ t = int(np.floor((c - ln) / 2))
+ tt = int(np.ceil((c - ln) / 2))
+
+ print(char * c)
+ print(char * t + " " * ln + char * tt)
+ print(char * t + txt + char * tt)
+ print(char * t + " " * ln + char * tt)
+ print(char * c)
+
+
+def print_footer():
+ c = cols()
+ char = "﹎"
+ print()
+ print(char * (c // len(char)))
+ print()
+ print(" " * (c // 2) + "•" + " " * (c - c // 2 - 1))
+ print()
+
+
+def extend_summary(summary, tmp_train_args_dict, tmp_template_dict, exclude=[]):
+ exclude = set(exclude)
+ if summary is None:
+ summary = defaultdict(list)
+ for k, v in tmp_template_dict.items():
+ if k not in exclude:
+ summary[k].append(v)
+ for k, v in tmp_train_args_dict.items():
+ if k not in exclude:
+ if isinstance(v, list):
+ v = str(v)
+ summary[k].append(v)
+ return summary
+
+
+def search_summary_table(summary, summary_dir=None):
+ # filter out constant values
+ summary = {k: v for k, v in summary.items() if len(set(v)) > 1}
+
+ # if everything is constant: no summary
+ if not summary:
+ return None, None
+
+ # find number of searches
+ n_searches = len(list(summary.values())[0])
+
+ # print section title
+ print(
+ "{}{}{}Varying values across {} experiments:{}\n".format(
+ C.OKBLUE,
+ C.BOLD,
+ C.UNDERLINE,
+ n_searches,
+ C.ENDC,
+ )
+ )
+
+ # first column holds the Exp. number
+ first_col = {
+ "len": 8, # length of a column, to split columns according to terminal width
+ "str": ["| Exp. |", "|:----:|"]
+ + [
+ "| {0:^{1}} |".format(i, 4) for i in range(n_searches)
+ ], # list of values to print
+ }
+
+ print_columns = [[first_col]]
+ file_columns = [first_col]
+ for k in sorted(summary.keys()):
+ v = summary[k]
+ col_title = f" {k} |"
+ col_blank_line = f":{'-' * len(k)}-|"
+ col_values = [
+ " {0:{1}} |".format(
+ crop_string(
+ str(crop_float(v[idx], min([5, len(k) - 2]))), len(k)
+ ), # crop floats and long strings
+ len(k),
+ )
+ for idx in range(len(v))
+ ]
+
+ # create column object
+ col = {"len": len(k) + 3, "str": [col_title, col_blank_line] + col_values}
+
+ # if adding a new column would overflow the terminal and mess up printing, start
+ # new set of columns
+ if sum(c["len"] for c in print_columns[-1]) + col["len"] >= cols():
+ print_columns.append([first_col])
+
+ # store current column to latest group of columns
+ print_columns[-1].append(col)
+ file_columns.append(col)
+
+ print_table = ""
+ # print each column group individually
+ for colgroup in print_columns:
+ # print columns line by line
+ for i in range(n_searches + 2):
+ # get value of column for current line i
+ for col in colgroup:
+ print_table += col["str"][i]
+ # next line for current columns
+ print_table += "\n"
+
+ # new lines for new column group
+ print_table += "\n"
+
+ file_table = ""
+ for i in range(n_searches + 2):
+ # get value of column for current line i
+ for col in file_columns:
+ file_table += col["str"][i]
+ # next line for current columns
+ file_table += "\n"
+
+ summary_path = None
+ if summary_dir is not None:
+ summary_path = summary_dir / (now() + ".md")
+ with summary_path.open("w") as f:
+ f.write(file_table.strip())
+
+ return print_table, summary_path
+
+
+def clean_arg(v):
+ """
+ chain cleaning function
+
+ Args:
+ v (any): arg to pass to train.py
+
+ Returns:
+ str: parsed value to string
+ """
+ return stringify_list(crop_float(quote_string(resolve_env(v))))
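+# For instance (illustrative, assuming $HOME resolves to /home/user):
+# clean_arg("$HOME/my run") -> '"/home/user/my run"'  (env var resolved, then quoted
+# because of the space; non-string values pass through crop_float / stringify_list)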
+
+
+def resolve_env(v):
+ """
+ resolve env variables in paths
+
+ Args:
+ v (any): arg to pass to train.py
+
+ Returns:
+ str: try and resolve an env variable
+ """
+ if isinstance(v, str):
+ try:
+ if "$" in v:
+ if "/" in v:
+ v = env_to_path(v)
+ else:
+                _v = os.environ.get(v.replace("$", ""))
+ if _v is not None:
+ v = _v
+ except Exception:
+ pass
+ return v
+
+
+def stringify_list(v):
+ """
+    Stringify list (with double quotes) so that it can be passed as an argument
+ to train.py's hydra command-line parsing
+
+ Args:
+ v (any): value to clean
+
+ Returns:
+ any: type of v, str if v was a list
+ """
+ if isinstance(v, list):
+ return '"{}"'.format(str(v).replace('"', "'"))
+ if isinstance(v, str):
+ if v.startswith("[") and v.endswith("]"):
+ return f'"{v}"'
+ return v
+
+
+def quote_string(v):
+ """
+ Add double quotes around string if it contains a " " or an =
+
+ Args:
+ v (any): value to clean
+
+ Returns:
+ any: type of v, quoted if v is a string with " " or =
+ """
+ if isinstance(v, str):
+ if " " in v or "=" in v:
+ return f'"{v}"'
+ return v
+
+
+def crop_float(v, k=5):
+ """
+    If v is a float, crop its precision to `k` significant digits and return a str
+
+ Args:
+ v (any): value to crop if float
+
+ Returns:
+ any: cropped float as str if v is a float, original v otherwise
+ """
+ if isinstance(v, float):
+ return "{0:.{1}g}".format(v, k)
+ return v
+
+
+def compute_n_search(conf):
+ """
+    Compute the number of searches to do if using -1 as n_search and using
+ cartesian or sequential search
+
+ Args:
+ conf (dict): experimental configuration
+
+ Returns:
+ int: size of the cartesian product or length of longest sequential field
+ """
+ samples = defaultdict(list)
+ for k, v in conf.items():
+ if not isinstance(v, dict) or "sample" not in v:
+ continue
+ samples[v["sample"]].append(v)
+
+ totals = []
+
+ if "cartesian" in samples:
+ total = 1
+ for s in samples["cartesian"]:
+ total *= len(s["from"])
+ totals.append(total)
+ if "sequential" in samples:
+ total = max(map(len, [s["from"] for s in samples["sequential"]]))
+ totals.append(total)
+
+ if totals:
+ return max(totals)
+
+ raise ValueError(
+ "Used n_search=-1 without any field being 'cartesian' or 'sequential'"
+ )
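+# Illustrative example: two cartesian fields with 2 and 3 candidate values plus one
+# sequential field with 4 values yield max(2 * 3, 4) = 6 runs when n_search=-1.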
+
+
+def crop_string(s, k=10):
+ if len(s) <= k:
+ return s
+ else:
+ return s[: k - 2] + ".."
+
+
+def sample_param(sample_dict):
+ """sample a value (hyperparameter) from the instruction in the
+ sample dict:
+ {
+ "sample": "range | list",
+ "from": [min, max, step] | [v0, v1, v2 etc.]
+ }
+ if range, as np.arange is used, "from" MUST be a list, but may contain
+ only 1 (=min) or 2 (min and max) values, not necessarily 3
+
+ Args:
+ sample_dict (dict): instructions to sample a value
+
+ Returns:
+ scalar: sampled value
+ """
+ if not isinstance(sample_dict, dict) or "sample" not in sample_dict:
+ return sample_dict
+
+ if sample_dict["sample"] == "cartesian":
+ assert isinstance(
+ sample_dict["from"], list
+ ), "{}'s `from` field MUST be a list, found {}".format(
+ sample_dict["sample"], sample_dict["from"]
+ )
+ return "__cartesian__"
+
+ if sample_dict["sample"] == "sequential":
+ assert isinstance(
+ sample_dict["from"], list
+ ), "{}'s `from` field MUST be a list, found {}".format(
+ sample_dict["sample"], sample_dict["from"]
+ )
+ return "__sequential__"
+
+ if sample_dict["sample"] == "range":
+ return np.random.choice(np.arange(*sample_dict["from"]))
+
+ if sample_dict["sample"] == "list":
+ return np.random.choice(sample_dict["from"])
+
+ if sample_dict["sample"] == "uniform":
+ return np.random.uniform(*sample_dict["from"])
+
+ raise ValueError("Unknown sample type in dict " + str(sample_dict))
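+# Illustrative example: sample_param({"sample": "list", "from": [1e-3, 1e-4]}) picks
+# one of the two values at random, while "cartesian" / "sequential" entries only
+# return a marker here and are resolved later, in resolve().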
+
+
+def sample_sequentials(sequential_keys, exp, idx):
+ """
+ Samples sequentially from the "from" values specified in each key of the
+ experimental configuration which have sample == "sequential"
+ Unlike `cartesian` sampling, `sequential` sampling iterates *independently*
+    over each key
+
+ Args:
+ sequential_keys (list): keys to be sampled sequentially
+ exp (dict): experimental config
+ idx (int): index of the current sample
+
+ Returns:
+ conf: sampled dict
+ """
+ conf = {}
+ for k in sequential_keys:
+ v = exp[k]["from"]
+ conf[k] = v[idx % len(v)]
+ return conf
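+# e.g. (illustrative) a key with from=[a, b] yields a, b, a, b, ... across
+# successive run indices, independently of the other sequential keys' lengths.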
+
+
+def sample_cartesians(cartesian_keys, exp, idx):
+ """
+ Returns the `idx`th item in the cartesian product of all cartesian keys to
+ be sampled.
+
+ Args:
+ cartesian_keys (list): keys in the experimental configuration that are to
+ be used in the full cartesian product
+ exp (dict): experimental configuration
+ idx (int): index of the current sample
+
+ Returns:
+ dict: sampled point in the cartesian space (with keys = cartesian_keys)
+ """
+ conf = {}
+ cartesian_values = [exp[key]["from"] for key in cartesian_keys]
+ product = list(itertools.product(*cartesian_values))
+ for k, v in zip(cartesian_keys, product[idx % len(product)]):
+ conf[k] = v
+ return conf
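+# Note: the full cartesian product is rebuilt on every call and indexed modulo its
+# size, so run indices beyond the number of combinations simply wrap around.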
+
+
+def resolve(hp_conf, nb):
+ """
+    Samples parameters parametrized in `hp_conf`: should be a dict with
+    values which fit `sample_param(dict)`'s API
+
+ Args:
+        hp_conf (dict): experiment's parametrization
+ nb (int): number of experiments to sample
+
+ Returns:
+ dict: sampled configuration
+ """
+ if nb == -1:
+ nb = compute_n_search(hp_conf)
+
+ confs = []
+ for idx in range(nb):
+ conf = {}
+ cartesians = []
+ sequentials = []
+ for k, v in hp_conf.items():
+ candidate = sample_param(v)
+ if candidate == "__cartesian__":
+ cartesians.append(k)
+ elif candidate == "__sequential__":
+ sequentials.append(k)
+ else:
+ conf[k] = candidate
+ if sequentials:
+ conf.update(sample_sequentials(sequentials, hp_conf, idx))
+ if cartesians:
+ conf.update(sample_cartesians(cartesians, hp_conf, idx))
+ confs.append(conf)
+ return confs
+
+
+def get_template_params(template):
+ """
+ extract args in template str as {arg}
+
+ Args:
+ template (str): sbatch template string
+
+ Returns:
+ list(str): Args required to format the template string
+ """
+ return map(
+ lambda s: s.replace("{", "").replace("}", ""),
+ re.findall("\{.*?\}", template), # noqa: W605
+ )
+
+
+def read_exp_conf(name):
+ """
+ Read hp search configuration from shared/experiment/
+ specified with or without the .yaml extension
+
+ Args:
+ name (str): name of the template to find in shared/experiment/
+
+ Returns:
+ Tuple(Path, dict): file path and loaded dict
+ """
+ if ".yaml" not in name:
+ name += ".yaml"
+ paths = []
+ dirs = ["shared", "config"]
+ for d in dirs:
+ path = Path(__file__).parent / d / "experiment" / name
+ if path.exists():
+ paths.append(path.resolve())
+
+ if len(paths) == 0:
+ failed = [Path(__file__).parent / d / "experiment" for d in dirs]
+ s = "Could not find search config {} in :\n".format(name)
+ for fd in failed:
+ s += str(fd) + "\nAvailable:\n"
+ for ym in fd.glob("*.yaml"):
+ s += " " + ym.name + "\n"
+ raise ValueError(s)
+
+ if len(paths) == 2:
+ print(
+ "Warning: found 2 relevant files for search config:\n{}".format(
+                "\n".join(str(p) for p in paths)
+ )
+ )
+ print("Using {}".format(paths[-1]))
+
+ with paths[-1].open("r") as f:
+ conf = yaml.safe_load(f)
+
+ flat_conf = {}
+ flatten_conf(conf, to=flat_conf)
+
+ return (paths[-1], flat_conf)
+
+
+def read_template(name):
+ """
+ Read template from shared/template/ specified with or without the .sh extension
+
+ Args:
+ name (str): name of the template to find in shared/template/
+
+ Returns:
+ str: file's content as 1 string
+ """
+ if ".sh" not in name:
+ name += ".sh"
+ paths = []
+ dirs = ["shared", "config"]
+ for d in dirs:
+ path = Path(__file__).parent / d / "template" / name
+ if path.exists():
+ paths.append(path)
+
+ if len(paths) == 0:
+ failed = [Path(__file__).parent / d / "template" for d in dirs]
+ s = "Could not find template {} in :\n".format(name)
+ for fd in failed:
+ s += str(fd) + "\nAvailable:\n"
+ for ym in fd.glob("*.sh"):
+ s += " " + ym.name + "\n"
+ raise ValueError(s)
+
+ if len(paths) == 2:
+        print(
+            "Warning: found 2 relevant template files:\n{}".format(
+                "\n".join(str(p) for p in paths)
+            )
+        )
+ print("Using {}".format(paths[-1]))
+
+ with paths[-1].open("r") as f:
+ return f.read()
+
+
+def is_sampled(key, conf):
+ """
+ Is a key sampled or constant? Returns true if conf is empty
+
+ Args:
+ key (str): key to check
+ conf (dict): hyper parameter search configuration dict
+
+ Returns:
+ bool: key is sampled?
+ """
+ return not conf or (
+ key in conf and isinstance(conf[key], dict) and "sample" in conf[key]
+ )
+
+
+if __name__ == "__main__":
+
+ """
+ Notes:
+ * Must provide template name as template=name
+ * `name`.sh should be in shared/template/
+ """
+
+ # -------------------------------
+ # ----- Default Variables -----
+ # -------------------------------
+
+ args = sys.argv[1:]
+ command_output = ""
+ user = os.environ.get("USER")
+ home = os.environ.get("HOME")
+ exp_conf = {}
+ dev = False
+ escape = False
+ verbose = False
+ template_name = None
+ hp_exp_name = None
+ hp_search_nb = None
+ exp_path = None
+ resume = None
+ force_sbatchs = False
+ sbatch_base = Path(home) / "climategan_sbatchs"
+ summary_dir = Path(home) / "climategan_exp_summaries"
+
+ hp_search_private = set(["n_search", "template", "search", "summary_dir"])
+
+ sbatch_path = "hash"
+
+ # --------------------------
+ # ----- Sanity Check -----
+ # --------------------------
+
+ for arg in args:
+ if "=" not in arg or " = " in arg:
+ raise ValueError(
+ "Args should be passed as `key=value`. Received `{}`".format(arg)
+ )
+
+ # --------------------------------
+ # ----- Parse Command Line -----
+ # --------------------------------
+
+ args_dict = {arg.split("=")[0]: arg.split("=")[1] for arg in args}
+
+ assert "template" in args_dict, "Please specify template=xxx"
+ template = read_template(args_dict["template"])
+ template_dict = {k: None for k in get_template_params(template)}
+
+ train_args = []
+ for k, v in args_dict.items():
+
+ if k == "verbose":
+ if v != "0":
+ verbose = True
+
+ elif k == "sbatch_path":
+ sbatch_path = v
+
+ elif k == "sbatch_base":
+ sbatch_base = Path(v).resolve()
+
+ elif k == "force_sbatchs":
+ force_sbatchs = v.lower() == "true"
+
+ elif k == "dev":
+ if v.lower() != "false":
+ dev = True
+
+ elif k == "escape":
+ if v.lower() != "false":
+ escape = True
+
+ elif k == "template":
+ template_name = v
+
+ elif k == "exp":
+ hp_exp_name = v
+
+ elif k == "n_search":
+ hp_search_nb = int(v)
+
+ elif k == "resume":
+ resume = f'"{v}"'
+ template_dict[k] = f'"{v}"'
+
+ elif k == "summary_dir":
+ if v.lower() == "none":
+ summary_dir = None
+ else:
+ summary_dir = Path(v)
+
+ elif k in template_dict:
+ template_dict[k] = v
+
+ else:
+ train_args.append(f"{k}={v}")
+
+ # ------------------------------------
+ # ----- Load Experiment Config -----
+ # ------------------------------------
+
+ if hp_exp_name is not None:
+ exp_path, exp_conf = read_exp_conf(hp_exp_name)
+ if "n_search" in exp_conf and hp_search_nb is None:
+ hp_search_nb = exp_conf["n_search"]
+
+ assert (
+ hp_search_nb is not None
+ ), "n_search should be specified in a yaml file or from the command line"
+
+ hps = resolve(exp_conf, hp_search_nb)
+
+ else:
+ hps = [None]
+
+ # ---------------------------------
+ # ----- Run All Experiments -----
+ # ---------------------------------
+ if summary_dir is not None:
+ summary_dir.mkdir(exist_ok=True, parents=True)
+ summary = None
+
+ for hp_idx, hp in enumerate(hps):
+
+ # copy shared values
+ tmp_template_dict = template_dict.copy()
+ tmp_train_args = train_args.copy()
+ tmp_train_args_dict = {
+ arg.split("=")[0]: arg.split("=")[1] for arg in tmp_train_args
+ }
+ print_header(hp_idx)
+ # override shared values with run-specific values for run hp_idx/n_search
+ if hp is not None:
+ for k, v in hp.items():
+ if k == "resume" and resume is None:
+ resume = f'"{v}"'
+ # hp-search params to ignore
+ if k in hp_search_private:
+ continue
+
+ if k == "codeloc":
+ v = escape_path(v)
+
+ if k == "output":
+ Path(v).parent.mkdir(parents=True, exist_ok=True)
+
+ # override template params depending on exp config
+ if k in tmp_template_dict:
+ if template_dict[k] is None or is_sampled(k, exp_conf):
+ tmp_template_dict[k] = v
+ # store sampled / specified params in current tmp_train_args_dict
+ else:
+ if k in tmp_train_args_dict:
+ if is_sampled(k, exp_conf):
+ # warn if key was specified from the command line
+ tv = tmp_train_args_dict[k]
+ warn(
+ "\nWarning: overriding sampled config-file arg",
+ "{} to command-line value {}\n".format(k, tv),
+ )
+ else:
+ tmp_train_args_dict[k] = v
+
+ # create sbatch file where required
+ tmp_sbatch_path = None
+ if sbatch_path == "hash":
+ tmp_sbatch_name = "" if hp_exp_name is None else hp_exp_name[:14] + "_"
+ tmp_sbatch_name += now() + ".sh"
+ tmp_sbatch_path = sbatch_base / tmp_sbatch_name
+ tmp_sbatch_path.parent.mkdir(parents=True, exist_ok=True)
+ tmp_train_args_dict["sbatch_file"] = str(tmp_sbatch_path)
+ tmp_train_args_dict["exp_file"] = str(exp_path)
+ else:
+ tmp_sbatch_path = Path(sbatch_path).resolve()
+
+ summary = extend_summary(
+ summary, tmp_train_args_dict, tmp_template_dict, exclude=["sbatch_file"]
+ )
+
+ # format train.py's args and crop floats' precision to 5 digits
+ tmp_template_dict["train_args"] = " ".join(
+ sorted(
+ [
+ "{}={}".format(k, clean_arg(v))
+ for k, v in tmp_train_args_dict.items()
+ ]
+ )
+ )
+
+ if "resume.py" in template and resume is None:
+ raise ValueError("No `resume` value but using a resume.py template")
+
+ # format template with clean dict (replace None with "")
+ sbatch = template.format(
+ **{
+ k: v if v is not None else ""
+ for k, v in tmp_template_dict.items()
+ if k in template_dict
+ }
+ )
+
+ # --------------------------------------
+ # ----- Execute `sbatch` Command -----
+ # --------------------------------------
+ if not dev or force_sbatchs:
+ if tmp_sbatch_path.exists():
+                print(f"Warning: overwriting {tmp_sbatch_path}")
+
+ # write sbatch file
+ with open(tmp_sbatch_path, "w") as f:
+ f.write(sbatch)
+
+ if not dev:
+ # escape special characters such as " " from sbatch_path's parent dir
+ parent = str(tmp_sbatch_path.parent)
+ if escape:
+ parent = escape_path(parent)
+
+ # create command to execute in a subprocess
+ command = "sbatch {}".format(tmp_sbatch_path.name)
+ # execute sbatch command & store output
+ command_output = subprocess.run(
+ command.split(), stdout=subprocess.PIPE, cwd=parent
+ )
+ command_output = "\n" + command_output.stdout.decode("utf-8") + "\n"
+
+ print(f"Running from {parent}:")
+ print(f"$ {command}")
+
+ # ---------------------------------
+ # ----- Summarize Execution -----
+ # ---------------------------------
+ if verbose:
+ print(C.BEIGE + C.ITALIC, "\n" + sbatch + C.ENDC)
+ if not dev:
+ print_box(command_output.strip())
+ jobID = parse_jobID(command_output.strip())
+ summary["Slurm JOBID"].append(jobID)
+
+ summary["Comet Link"].append(f"[{hp_idx}][{hp_idx}]")
+
+ print(
+ "{}{}Summary{} {}:".format(
+ C.UNDERLINE,
+ C.OKGREEN,
+ C.ENDC,
+ f"{C.WARNING}(DEV){C.ENDC}" if dev else "",
+ )
+ )
+ print(
+ " "
+ + "\n ".join(
+ "{:10}: {}".format(k, v) for k, v in tmp_template_dict.items()
+ )
+ )
+ print_footer()
+
+ print(f"\nRan a total of {len(hps)} jobs{' in dev mode.' if dev else '.'}\n")
+
+ table, sum_path = search_summary_table(summary, summary_dir if not dev else None)
+ if table is not None:
+ print(table)
+ print(
+ "Add `[i]: https://...` at the end of a markdown document",
+ "to fill in the comet links.\n",
+ )
+ if summary_dir is None:
+ print("Add summary_dir=path to store the printed markdown table ⇪")
+ else:
+ print("Saved table in", str(sum_path))
+
+ if not dev:
+ print(
+ "Cancel entire experiment? \n$ scancel",
+ " ".join(map(str, summary["Slurm JOBID"])),
+ )
diff --git a/shared/experiment/showcase.yaml b/shared/experiment/showcase.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..1a6c89ad59db5d358ed61b31e67e72f844dcb2cc
--- /dev/null
+++ b/shared/experiment/showcase.yaml
@@ -0,0 +1,71 @@
+--- # ---------------------------
+# `sample` can be
+# - `uniform` (np.random.uniform(*from))
+# - `range` (np.choice(np.arange(*from)))
+# - `list` (np.choice(from))
+# - `cartesian` special case where a cartesian product of all keys with the `cartesian` sampling scheme
+# is created and iterated over in order. `from` MUST be a list
+# As we iterate over the cartesian product of all
+# such keys, others are sampled as usual. If n_search is larger than the size of the cartesian
+# product, it will cycle again through the product in the same order
+# example with A being `cartesian` from [1, 2] and B from [y, z] and 5 searches:
+# => {A:1, B: y}, {A:1, B: z}, {A:2, B: y}, {A:2, B: z}, {A:1, B: y}
+# - `sequential` samples will loop through the values in `from`. `from` MUST be a list
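+#   example with `sequential` from [2, 4, 6] and 5 searches (illustrative, assuming it
+#   loops back to the start when n_search exceeds the list length): 2, 4, 6, 2, 4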
+
+# ---------------------------
+# ----- SBATCH config -----
+cpus: 8
+partition: long
+mem: 32G
+gres: "gpu:rtx8000:1"
+codeloc: $HOME/ccai/climategan
+
+modules: "module load anaconda/3 && module load pytorch"
+conda: "conda activate climatenv && conda deactivate && conda activate climatenv"
+
+n_search: -1
+
+# ------------------------
+# ----- Train Args -----
+# ------------------------
+
+"args.note": "Hyper Parameter search #1"
+"args.comet_tags": ["masker_search", "v1"]
+"args.config": "config/trainer/my_config.yaml"
+
+# --------------------------
+# ----- Model config -----
+# --------------------------
+"gen.opt.lr":
+ sample: list
+ from: [0.01, 0.001, 0.0001, 0.00001]
+
+"dis.opt.lr":
+ sample: uniform
+ from: [0.01, 0.001]
+
+"dis.opt.optimizer":
+ sample: cartesian
+ from:
+ - ExtraAdam
+ - Adam
+
+"gen.opt.optimizer":
+ sample: cartesian
+ from:
+ - ExtraAdam
+ - Adam
+
+"gen.lambdas.C":
+ sample: cartesian
+ from:
+ - 0.1
+ - 0.5
+ - 1
+
+"data.loaders.batch_size":
+ sample: sequential
+ from:
+ - 2
+ - 4
+ - 6
diff --git a/shared/template/mila_victor.sh b/shared/template/mila_victor.sh
new file mode 100644
index 0000000000000000000000000000000000000000..343d88c5094be656c76fe0348709f2a5c44875b3
--- /dev/null
+++ b/shared/template/mila_victor.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+#SBATCH --partition={partition}
+#SBATCH --cpus-per-task={cpus}
+#SBATCH --mem={mem}
+#SBATCH --gres={gres}
+#SBATCH --output={output}
+
+module purge
+
+{modules}
+
+{conda}
+
+export PYTHONUNBUFFERED=1
+
+cd {codeloc}
+
+echo "Currently using:"
+echo $(which python)
+echo "in:"
+echo $(pwd)
+echo "sbatch file name: $0"
+
+python train.py {train_args}
\ No newline at end of file
diff --git a/shared/template/resume_mila_victor.sh b/shared/template/resume_mila_victor.sh
new file mode 100644
index 0000000000000000000000000000000000000000..2a5bcac63bdf841406afc9718a31dcfc8bf4df33
--- /dev/null
+++ b/shared/template/resume_mila_victor.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+#SBATCH --partition={partition}
+#SBATCH --cpus-per-task={cpus}
+#SBATCH --mem={mem}
+#SBATCH --gres={gres}
+#SBATCH --output={output}
+
+module purge
+
+{modules}
+
+{conda}
+
+export PYTHONUNBUFFERED=1
+
+cd {codeloc}
+
+echo "Currently using:"
+echo $(which python)
+echo "in:"
+echo $(pwd)
+echo "sbatch file: $0"
+
+python resume.py --path {resume}
\ No newline at end of file
diff --git a/shared/trainer/config.yaml b/shared/trainer/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..b423d40e2086dec526f68d69ca02d67a8cb76fe4
--- /dev/null
+++ b/shared/trainer/config.yaml
@@ -0,0 +1,16 @@
+# HYDRA CONFIG
+
+# defaults:
+# - defaults
+
+args:
+ config: null # "What configuration file to use to overwrite shared/defaults.yaml"
+ note: null # Note about this training for comet logging
+ no_comet: False # DON'T use comet.ml to log experiment
+ resume: False # Load latest ckpt
+ tags: null
+ dev: False # Run this script in development mode
+
+hydra:
+ run:
+ dir: .
diff --git a/shared/trainer/defaults.yaml b/shared/trainer/defaults.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d5d08cd61ca0680325ea7d0abe3cadbcbcc4b4e9
--- /dev/null
+++ b/shared/trainer/defaults.yaml
@@ -0,0 +1,334 @@
+output_path: /miniscratch/_groups/ccai/trash
+# README on load_path
+# 1/ any path which leads to a dir will be loaded as `path / checkpoints / latest_ckpt.pth`
+# 2/ if you want to specify a specific checkpoint, it MUST be a `.pth` file
+# 3/ when resuming a P OR an M model, you may only specify 1 of `load_path.p` OR `load_path.m`.
+# You may also leave BOTH at none, in which case `output_path / checkpoints / latest_ckpt.pth`
+# will be used
+# 4/ when resuming a P+M model, you may specify (`p` AND `m`) OR `pm`, OR leave all at none,
+# in which case `output_path / checkpoints / latest_ckpt.pth` will be used to load from
+# a single checkpoint
+load_paths:
+ p: none # Painter weights: none will use `output_path / checkpoints / latest_ckpt.pth`
+ m: none # Masker weights: none will use `output_path / checkpoints / latest_ckpt.pth`
+ pm: none # Painter and Masker weights: none will use `output_path / checkpoints / latest_ckpt.pth`
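+# e.g. (illustrative path) resuming a P+M model from a single checkpoint (rule 4/ above):
+#   load_paths: {p: none, m: none, pm: /path/to/run/checkpoints/epoch_100_ckpt.pth}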
+
+# -------------------
+# ----- Tasks -----
+# -------------------
+tasks: [d, s, m, p] # [p] [m, s, d]
+
+# ----------------
+# ----- Data -----
+# ----------------
+data:
+ max_samples: -1 # -1 for all, otherwise set to an int to crop the training data size
+ files: # if one is not none it will override the dirs location
+ base: /miniscratch/_groups/ccai/data/jsons
+ train:
+ r: train_r_full.json
+ s: train_s_fixedholes.json
+ rf: train_rf.json
+ kitti: train_kitti.json
+ val:
+ r: val_r_full.json
+ s: val_s_fixedholes.json
+ rf: val_rf_labelbox.json
+ kitti: val_kitti.json
+ check_samples: False
+ loaders:
+ batch_size: 6
+ num_workers: 6
+ normalization: default # can be "default" or "HRNet" for now. # default: mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]; HRNet: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+ transforms:
+ - name: hflip
+ ignore: val
+ p: 0.5
+ - name: resize
+ ignore: false
+ new_size: 640
+ keep_aspect_ratio: true # smallest dimension will be `new_size` and the other will be computed to keep aspect ratio
+ - name: crop
+ ignore: false
+ center: val # disable randomness, crop around the image's center
+ height: 600
+ width: 600
+ - name: brightness
+ ignore: val
+ - name: saturation
+ ignore: val
+ - name: contrast
+ ignore: val
+ - name: resize
+ ignore: false
+ new_size:
+ default: 640
+ d: 160
+ s: 160
+
+# ---------------------
+# ----- Generator -----
+# ---------------------
+gen:
+ opt:
+ optimizer: ExtraAdam # one in [Adam, ExtraAdam] default: Adam
+ beta1: 0.9
+ lr:
+ default: 0.00005 # 0.00001 for dlv2, 0.00005 for dlv3
+ lr_policy: step
+ # lr_policy can be constant, step or multi_step; if step, specify lr_step_size and lr_gamma
+ # if multi_step specify lr_step_size lr_gamma and lr_milestones:
+ # if lr_milestones is a list:
+ # the learning rate will be multiplied by gamma each time the epoch reaches an
+ # item in the list (no need for lr_step_size).
+ # if lr_milestones is an int:
+ # a list of milestones is created from `range(lr_milestones, train.epochs, lr_step_size)`
+ lr_step_size: 5 # for linear decay : period of learning rate decay (epochs)
+ lr_milestones: 15
+ lr_gamma: 0.5 # Multiplicative factor of learning rate decay
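+    # e.g. (illustrative): with lr_policy: multi_step, lr_milestones: 15, lr_step_size: 5,
+    # lr_gamma: 0.5 and train.epochs: 300, the lr would be halved at epochs 15, 20, 25, ..., 295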
+ default:
+ &default-gen # default parameters for the generator (encoder and decoders)
+ activ: lrelu # activation function [relu/lrelu/prelu/selu/tanh]
+ init_gain: 0.02
+ init_type: xavier
+ n_res: 1 # number of residual blocks before upsampling
+ n_downsample: &n_downsample 3 # number of downsampling layers in encoder | dim 32 + down 3 => z = 256 x 32 x 32
+ n_upsample: *n_downsample # upsampling in spade decoder ; should match encoder.n_downsample
+ pad_type: reflect # padding type [zero/reflect]
+ norm: spectral # ResBlock normalization ; one of {"batch", "instance", "layer", "adain", "spectral", "none"}
+ proj_dim: 32 # Dim of projection from latent space
+ encoder: # specific params for the encoder
+ <<: *default-gen
+ dim: 32
+    architecture: deeplabv3 # deeplabv2/v3 resnet -> res_dim=2048 | dlv3 mobilenet -> res_dim=320
+ input_dim: 3 # input number of channels
+ n_res: 0 # number of residual blocks in content encoder/decoder
+ norm: spectral # ConvBlock normalization ; one of {"batch", "instance", "layer", "adain", "spectral", "none"}
+
+ #! Don't change!!!
+ deeplabv2:
+ nblocks: [3, 4, 23, 3]
+ use_pretrained: True
+ pretrained_model: "/miniscratch/_groups/ccai/data/pretrained_models/deeplabv2/DeepLab_resnet_pretrained_imagenet.pth"
+
+ deeplabv3:
+ backbone: resnet # resnet or mobilenet
+ output_stride: 8 # 8 or 16
+ use_pretrained: true
+ pretrained_model:
+ mobilenet: "/miniscratch/_groups/ccai/data/pretrained_models/deeplabv3/deeplabv3_plus_mobilenetv2_segmentron.pth"
+ resnet: "/miniscratch/_groups/ccai/data/pretrained_models/deeplabv3/model_CoinCheungDeepLab-v3-plus.pth"
+
+ d: # specific params for the depth estimation decoder
+ <<: *default-gen
+ output_dim: 1
+ norm: batch
+ loss: sigm # dada or sigm | /!\ ignored if classify.enable
+ upsample_featuremaps: True # upsamples from 80x80 to 160x160 intermediate feature maps
+ architecture: dada # dada or base | must be base for classif
+ classify: # classify log-depth instead of regression
+ enable: False
+ linspace:
+ min: 0.35
+ max: 6.95
+ buckets: 256
+ s: # specific params for the semantic segmentation decoder
+ <<: *default-gen
+ num_classes: 11
+ output_dim: 11
+ use_advent: True
+ use_minent: True
+ architecture: deeplabv3
+ upsample_featuremaps: False # upsamples from 80x80 to 160x160 intermediate feature maps
+ use_dada: True
+ p: # specific params for the SPADE painter
+ <<: *default-gen
+ latent_dim: 640
+ loss: gan # gan or hinge
+ no_z: true # <=> use_vae=False in the SPADE repo
+ output_dim: 3 # output dimension
+ pad_type: reflect # padding type [zero/reflect]
+ paste_original_content: True # only select the water painted to backprop through the network, not the whole generated image: fake_flooded = masked_x + m * fake_flooded
+ pl4m_epoch: 49 # epoch from which we introduce a new loss to the masker: the painter's discriminator's loss
+ spade_kernel_size: 3 # kernel size within SPADE norm layers
+ spade_n_up: 7 # number of upsampling layers in the translation decoder is equal to number of downsamplings in the encoder. output's h and w are z's h and w x 2^spade_num_upsampling_layers | z:32 and spade_n_up:4 => output 512
+ spade_param_free_norm: instance # what param-free normalization to apply in SPADE normalization
+ spade_use_spectral_norm: true
+ use_final_shortcut: False # if true, the last spade block does not get the masked input as conditioning but the prediction of the previous layer (passed through a conv to match dims) in order to lighten the masking restrictions and have smoother edges
+ diff_aug:
+ use: False
+ do_color_jittering: false
+ do_cutout: false
+ cutout_ratio: 0.5
+ do_translation: false
+ translation_ratio: 0.125
+
+ m: # specific params for the mask-generation decoder
+ <<: *default-gen
+ use_spade: False
+ output_dim: 1
+ use_minent: True # directly minimize the entropy of the image
+ use_minent_var: True # add variance of entropy map in the measure of entropy for a certain picture
+ use_advent: True # minimize the entropy of the image by adversarial training
+ use_ground_intersection: True
+ use_proj: True
+ proj_dim: 64
+ use_pl4m: False
+ n_res: 3
+ use_low_level_feats: True
+ use_dada: False
+ spade:
+ latent_dim: 128
+ detach: false # detach s_pred and d_pred conditioning tensors
+ cond_nc: 15 # 12 without x, 15 with x
+ spade_use_spectral_norm: True
+ spade_param_free_norm: batch
+ num_layers: 3
+ activations:
+ all_lrelu: True
+
+# -------------------------
+# ----- Discriminator -----
+# -------------------------
+dis:
+ soft_shift: 0.2 # label smoothing: real in U(1-soft_shift, 1), fake in U(0, soft_shift) # ! one-sided label smoothing
+ flip_prob: 0.05 # label flipping
+ opt:
+ optimizer: ExtraAdam # one in [Adam, ExtraAdam] default: Adam
+ beta1: 0.5
+ lr:
+ default: 0.00002 # 0.0001 for dlv2, 0.00002 for dlv3
+ lr_policy: step
+ # lr_policy can be constant, step or multi_step; if step, specify lr_step_size and lr_gamma
+ # if multi_step specify lr_step_size lr_gamma and lr_milestones:
+ # if lr_milestones is a list:
+ # the learning rate will be multiplied by gamma each time the epoch reaches an
+ # item in the list (no need for lr_step_size).
+ # if lr_milestones is an int:
+ # a list of milestones is created from `range(lr_milestones, train.epochs, lr_step_size)`
+ lr_step_size: 15 # for linear decay : period of learning rate decay (epochs)
+ lr_milestones: 5
+ lr_gamma: 0.5 # Multiplicative factor of learning rate decay
+ default:
+ &default-dis # default setting for discriminators (there are 4 of them for rn rf sn sf)
+ input_nc: 3
+ ndf: 64
+ n_layers: 4
+ norm: instance
+ init_type: xavier
+ init_gain: 0.02
+ use_sigmoid: false
+ num_D: 1 #Number of discriminators to use (>1 means multi-scale)
+ get_intermediate_features: false
+ p:
+ <<: *default-dis
+ num_D: 3
+ get_intermediate_features: true
+ use_local_discriminator: false
+ # ttur: false # two time-scale update rule (see SPADE repo)
+ m:
+ <<: *default-dis
+ multi_level: false
+ architecture: base # can be [base | OmniDiscriminator]
+ gan_type: WGAN_norm # can be [GAN | WGAN | WGAN_gp | WGAN_norm]
+    wgan_clamp_lower: -0.01 # used in WGAN: clamps the dis params to [wgan_clamp_lower, wgan_clamp_upper] at every update
+ wgan_clamp_upper: 0.01 # used in WGAN
+ s:
+ <<: *default-dis
+ gan_type: WGAN_norm # can be [GAN | WGAN | WGAN_gp | WGAN_norm]
+    wgan_clamp_lower: -0.01 # used in WGAN: clamps the dis params to [wgan_clamp_lower, wgan_clamp_upper] at every update
+ wgan_clamp_upper: 0.01 # used in WGAN
+# -------------------------------
+# ----- Domain Classifier -----
+# -------------------------------
+classifier:
+ opt:
+ optimizer: ExtraAdam # one in [Adam, ExtraAdam] default: Adam
+ beta1: 0.5
+ lr:
+ default: 0.0005
+ lr_policy: step # constant or step ; if step, specify step_size and gamma
+ lr_step_size: 30 # for linear decay
+ lr_gamma: 0.5
+ loss: l2 #Loss can be l1, l2, cross_entropy. default cross_entropy
+  layers: [100, 100, 20, 20, 4] # number of units per hidden layer; last number is output_dim
+ dropout: 0.4 # probability of being set to 0
+ init_type: kaiming
+ init_gain: 0.2
+ proj_dim: 128 #Dim of projection from latent space
+
+# ------------------------
+# ----- Train Params -----
+# ------------------------
+train:
+ kitti:
+ pretrain: False
+ epochs: 10
+ batch_size: 6
+ amp: False
+ pseudo:
+ tasks: [] # list of tasks for which to use pseudo labels (empty list to disable)
+ epochs: 10 # disable pseudo training after n epochs (set to -1 to never disable)
+ epochs: 300
+ fid:
+ n_images: 57 # val_rf.json has 57 images
+ batch_size: 50 # inception inference batch size, not painter's
+ dims: 2048 # what Inception bock to compute the stats from (see BLOCK_INDEX_BY_DIM in fid.py)
+ latent_domain_adaptation: False # whether or not to do domain adaptation on the latent vectors # Needs to be turned off if use_advent is True
+ lambdas: # scaling factors in the total loss
+ G:
+ d:
+ main: 1
+ gml: 0.5
+ s:
+ crossent: 1
+ crossent_pseudo: 0.001
+ minent: 0.001
+ advent: 0.001
+ m:
+ bce: 1 # Main prediction loss, i.e. GAN or BCE
+ tv: 1 # Total variational loss (for smoothing)
+ gi: 0.05
+ pl4m: 1 # painter loss for the masker (end-to-end)
+ p:
+ context: 0
+ dm: 1 # depth matching
+ featmatch: 10
+ gan: 1 # gan loss
+ reconstruction: 0
+ tv: 0
+ vgg: 10
+ classifier: 1
+ C: 1
+ advent:
+ ent_main: 0.5 # the coefficient of the MinEnt loss that directly minimize the entropy of the image
+ ent_aux: 0.0 # the corresponding coefficient of the MinEnt loss of second output
+ ent_var: 0.1 # the proportion of variance of entropy map in the entropy measure for a certain picture
+ adv_main: 1.0 # the coefficient of the AdvEnt loss that minimize the entropy of the image by adversarial training
+ adv_aux: 0.0 # the corresponding coefficient of the AdvEnt loss of second output
+ dis_main: 1.0 # the discriminator take care of the first output in the adversarial training
+ dis_aux: 0.0 # the discriminator take care of the second output in the adversarial training
+ WGAN_gp: 10 # used in WGAN_gp, it's the hyperparameters for the gradient penalty
+ log_level: 2 # 0: no log, 1: only aggregated losses, >1 detailed losses
+ save_n_epochs: 25 # Save `latest_ckpt.pth` every epoch, `epoch_{epoch}_ckpt.pth` model every n epochs if epoch >= min_save_epoch
+ min_save_epoch: 28 # Save extra intermediate checkpoints when epoch > min_save_epoch
+ resume: false # Load latest_ckpt.pth checkpoint from `output_path` #TODO Make this path of checkpoint to load
+ auto_resume: true # automatically looks for similar output paths and exact same jobID to resume training automatically even if resume is false.
+
+# -----------------------------
+# ----- Validation Params -----
+# -----------------------------
+val:
+ store_images: false # write to disk on top of comet logging
+ val_painter: /miniscratch/_groups/ccai/checkpoints/painter/victor/good_large_lr/checkpoints/latest_ckpt.pth
+# -----------------------------
+# ----- Comet Params ----------
+# -----------------------------
+comet:
+ display_size: 20
+ rows_per_log: 5 # number of samples (rows) in a logged grid image. Number of total logged images: display_size // rows_per_log
+ im_per_row: # how many columns (3 = x, target, pred)
+ p: 4
+ m: 6
+ s: 4
+ d: 4
diff --git a/shared/trainer/events.yaml b/shared/trainer/events.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2cb265620797a5c0ed3b598aab37427dddb245f6
--- /dev/null
+++ b/shared/trainer/events.yaml
@@ -0,0 +1,14 @@
+fire:
+ kernel_size: 281
+ kernel_sigma: 140.5
+ transparency: 200
+ sky_inc_factor: 0.12
+ contrast_factor: 1.5
+ brightness_factor: 0.95
+ crop_bottom_sky_mask: true
+smog:
+ airlight: 0.76
+ beta: 2
+ vr: 1
+ yellow_color: [224, 192, 29]
+ alpha: 20
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b84153c169aab557a3aee937ea13429c62a72ec0
--- /dev/null
+++ b/tests/test_trainer.py
@@ -0,0 +1,384 @@
+print("Imports...", end="", flush=True)
+
+import sys
+from pathlib import Path
+
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+
+import atexit
+import logging
+from argparse import ArgumentParser
+from copy import deepcopy
+
+import comet_ml
+import climategan
+from comet_ml.api import API
+from climategan.trainer import Trainer
+from climategan.utils import get_comet_rest_api_key
+
+logging.basicConfig()
+logging.getLogger().setLevel(logging.ERROR)
+import traceback
+
+print("Done.")
+
+
+def set_opts(opts, str_nested_key, value):
+ """
+ Changes an opts with nested keys:
+ set_opts(addict.Dict(), "a.b.c", 2) == Dict({"a":{"b": {"c": 2}}})
+
+ Args:
+ opts (addict.Dict): opts whose values should be changed
+ str_nested_key (str): nested keys joined on "."
+ value (any): value to set to the nested keys of opts
+ """
+ keys = str_nested_key.split(".")
+ o = opts
+ for k in keys[:-1]:
+ o = o[k]
+ o[keys[-1]] = value
+
+
+def set_conf(opts, conf):
+ """
+ Updates opts according to a test scenario's configuration dict.
+ Ignores all keys starting with "__" which are used for the scenario
+ but outside the opts
+
+ Args:
+ opts (addict.Dict): trainer options
+ conf (dict): scenario's configuration
+ """
+ for k, v in conf.items():
+ if k.startswith("__"):
+ continue
+ set_opts(opts, k, v)
+
+
+class bcolors:
+ HEADER = "\033[95m"
+ OKBLUE = "\033[94m"
+ OKGREEN = "\033[92m"
+ WARNING = "\033[93m"
+ FAIL = "\033[91m"
+ ENDC = "\033[0m"
+ BOLD = "\033[1m"
+ UNDERLINE = "\033[4m"
+
+
+class Colors:
+ def _r(self, key, *args):
+ return f"{key}{' '.join(args)}{bcolors.ENDC}"
+
+ def ob(self, *args):
+ return self._r(bcolors.OKBLUE, *args)
+
+ def w(self, *args):
+ return self._r(bcolors.WARNING, *args)
+
+ def og(self, *args):
+ return self._r(bcolors.OKGREEN, *args)
+
+ def f(self, *args):
+ return self._r(bcolors.FAIL, *args)
+
+ def b(self, *args):
+ return self._r(bcolors.BOLD, *args)
+
+ def u(self, *args):
+ return self._r(bcolors.UNDERLINE, *args)
+
+
+def comet_handler(exp, api):
+ def sub_handler():
+ p = Colors()
+ print()
+ print(p.b(p.w("Deleting comet experiment")))
+ api.delete_experiment(exp.get_key())
+
+ return sub_handler
+
+
+def print_start(desc):
+ p = Colors()
+ cdesc = p.b(p.ob(desc))
+ title = "| " + cdesc + " |"
+ line = "-" * (len(desc) + 6)
+ print(f"{line}\n{title}\n{line}")
+
+
+def print_end(desc=None, ok=None):
+ p = Colors()
+ if ok and desc is None:
+ desc = "Done"
+ cdesc = p.b(p.og(desc))
+ elif not ok and desc is None:
+ desc = "! Fail !"
+ cdesc = p.b(p.f(desc))
+ elif desc is not None:
+ cdesc = p.b(p.og(desc))
+ else:
+ desc = "Unknown"
+ cdesc = desc
+
+ title = "| " + cdesc + " |"
+ line = "-" * (len(desc) + 6)
+ print(f"{line}\n{title}\n{line}\n")
+
+
+def delete_on_exit(exp):
+ """
+ Registers a callback to delete the comet exp at program exit
+
+ Args:
+ exp (comet_ml.Experiment): The exp to delete
+ """
+ rest_api_key = get_comet_rest_api_key()
+ api = API(api_key=rest_api_key)
+ atexit.register(comet_handler(exp, api))
+
+
+if __name__ == "__main__":
+
+ # -----------------------------
+ # ----- Parse Arguments -----
+ # -----------------------------
+ parser = ArgumentParser()
+ parser.add_argument("--no_delete", action="store_true", default=False)
+ parser.add_argument("--no_end_to_end", action="store_true", default=False)
+ parser.add_argument("--include", "-i", nargs="+", default=[])
+ parser.add_argument("--exclude", "-e", nargs="+", default=[])
+ args = parser.parse_args()
+
+ assert not (args.include and args.exclude), "Choose 1: include XOR exclude"
+
+ include = set(int(i) for i in args.include)
+ exclude = set(int(i) for i in args.exclude)
+ if include:
+ print("Including exclusively tests", " ".join(args.include))
+ if exclude:
+ print("Excluding tests", " ".join(args.exclude))
+
+ # --------------------------------------
+ # ----- Create global experiment -----
+ # --------------------------------------
+ print("Creating comet Experiment...", end="", flush=True)
+ global_exp = comet_ml.Experiment(
+ project_name="climategan-test", display_summary_level=0
+ )
+        print("Warning: found 2 relevant template files:\n{}".format("\n".join(str(p) for p in paths)))
+
+ if not args.no_delete:
+ delete_on_exit(global_exp)
+
+ # prompt util for colors
+ prompt = Colors()
+
+ # -------------------------------------
+ # ----- Base Test Scenario Opts -----
+ # -------------------------------------
+ print("Loading opts...", end="", flush=True)
+ base_opts = climategan.utils.load_opts()
+ base_opts.data.check_samples = False
+ base_opts.train.fid.n_images = 5
+ base_opts.comet.display_size = 5
+ base_opts.tasks = ["m", "s", "d"]
+ base_opts.domains = ["r", "s"]
+ base_opts.data.loaders.num_workers = 4
+ base_opts.data.loaders.batch_size = 2
+ base_opts.data.max_samples = 9
+ base_opts.train.epochs = 1
+ if isinstance(base_opts.data.transforms[-1].new_size, int):
+ base_opts.data.transforms[-1].new_size = 256
+ else:
+ base_opts.data.transforms[-1].new_size.default = 256
+ print("Done.")
+
+ # --------------------------------------
+ # ----- Configure Test Scenarios -----
+ # --------------------------------------
+
+ # override any nested key in opts
+ # create scenario-specific variables with __key
+ # ALWAYS specify a __doc key to describe your scenario
+ test_scenarios = [
+ {"__use_comet": False, "__doc": "MSD no exp", "__verbose": 1}, # 0
+ {"__doc": "MSD with exp"}, # 1
+ {
+ "__doc": "MSD no exp upsample_featuremaps", # 2
+ "__use_comet": False,
+ "gen.d.upsample_featuremaps": True,
+ "gen.s.upsample_featuremaps": True,
+ },
+ {"tasks": ["p"], "domains": ["rf"], "__doc": "Painter"}, # 3
+ {
+ "__doc": "M no exp low level feats", # 4
+ "__use_comet": False,
+ "gen.m.use_low_level_feats": True,
+ "gen.m.use_dada": False,
+ "tasks": ["m"],
+ },
+ {
+ "__doc": "MSD no exp deeplabv2", # 5
+ "__use_comet": False,
+ "gen.encoder.architecture": "deeplabv2",
+ "gen.s.architecture": "deeplabv2",
+ },
+ {
+ "__doc": "MSDP no End-to-end", # 6
+ "domains": ["rf", "r", "s"],
+ "tasks": ["m", "s", "d", "p"],
+ },
+ {
+ "__doc": "MSDP inference only no exp", # 7
+ "__inference": True,
+ "__use_comet": False,
+ "domains": ["rf", "r", "s"],
+ "tasks": ["m", "s", "d", "p"],
+ },
+ {
+ "__doc": "MSDP with End-to-end", # 8
+ "__pl4m": True,
+ "domains": ["rf", "r", "s"],
+ "tasks": ["m", "s", "d", "p"],
+ },
+ {
+ "__doc": "Kitti pretrain", # 9
+ "train.epochs": 2,
+ "train.kitti.pretrain": True,
+ "train.kitti.epochs": 1,
+ "domains": ["kitti", "r", "s"],
+ "train.kitti.batch_size": 2,
+ },
+ {"__doc": "Depth Dada archi", "gen.d.architecture": "dada"}, # 10
+ {
+ "__doc": "Depth Base archi",
+ "gen.d.architecture": "base",
+ "gen.m.use_dada": False,
+ "gen.s.use_dada": False,
+ }, # 11
+ {
+ "__doc": "Depth Base Classification", # 12
+ "gen.d.architecture": "base",
+ "gen.d.classify.enable": True,
+ "gen.m.use_dada": False,
+ "gen.s.use_dada": False,
+ },
+ {
+ "__doc": "MSD Resnet V3+ backbone",
+ "gen.deeplabv3.backbone": "resnet",
+ }, # 13
+ {
+ "__use_comet": False,
+ "__doc": "MSD SPADE 12 (without x)",
+ "__verbose": 1,
+ "gen.m.use_spade": True,
+ "gen.m.spade.cond_nc": 12,
+ }, # 14
+ {
+ "__use_comet": False,
+ "__doc": "MSD SPADE 15 (with x)",
+ "__verbose": 1,
+ "gen.m.use_spade": True,
+ "gen.m.spade.cond_nc": 15,
+ }, # 15
+ {
+ "__use_comet": False,
+ "__doc": "Painter With Diff Augment",
+ "__verbose": 1,
+ "domains": ["rf"],
+ "tasks": ["p"],
+ "gen.p.diff_aug.use": True,
+        },  # 16
+ {
+ "__use_comet": False,
+ "__doc": "MSD DADA_s",
+ "__verbose": 1,
+ "gen.s.use_dada": True,
+ "gen.m.use_dada": False,
+        },  # 17
+ {
+ "__use_comet": False,
+ "__doc": "MSD DADA_ms",
+ "__verbose": 1,
+ "gen.s.use_dada": True,
+ "gen.m.use_dada": True,
+        },  # 18
+ ]
+
+ n_confs = len(test_scenarios)
+
+ fails = []
+ successes = []
+
+ # --------------------------------
+ # ----- Run Test Scenarios -----
+ # --------------------------------
+
+ for test_idx, conf in enumerate(test_scenarios):
+ if test_idx in exclude or (include and test_idx not in include):
+ reason = (
+ "because it is in exclude"
+ if test_idx in exclude
+ else "because it is not in include"
+ )
+ print("Ignoring test", test_idx, reason)
+ continue
+
+ # copy base scenario opts
+ test_opts = deepcopy(base_opts)
+ # update with scenario configuration
+ set_conf(test_opts, conf)
+
+ # print scenario description
+ print_start(
+ f"[{test_idx}/{n_confs - 1}] "
+ + conf.get("__doc", "WARNING: no __doc for test scenario")
+ )
+ print()
+
+ comet = conf.get("__use_comet", True)
+ pl4m = conf.get("__pl4m", False)
+ inference = conf.get("__inference", False)
+ verbose = conf.get("__verbose", 0)
+
+ # set (or not) experiment
+ test_exp = None
+ if comet:
+ test_exp = global_exp
+
+ try:
+ # create trainer
+ trainer = Trainer(
+ opts=test_opts,
+ verbose=verbose,
+ comet_exp=test_exp,
+ )
+ trainer.functional_test_mode()
+
+ # set (or not) painter loss for masker (= end-to-end)
+ if pl4m:
+ trainer.use_pl4m = True
+
+ # test training procedure
+ trainer.setup(inference=inference)
+ if not inference:
+ trainer.train()
+
+ successes.append(test_idx)
+ ok = True
+ except Exception as e:
+ print(e)
+ print(traceback.format_exc())
+ fails.append(test_idx)
+ ok = False
+ finally:
+ print_end(ok=ok)
+
+ print_end(desc=" ----- Summary ----- ")
+ if len(fails) == 0:
+ print("•• All scenarios were successful")
+ else:
+ print(f"•• {len(successes)}/{len(test_scenarios)} successful tests")
+ print(f"•• Failed test indices: {', '.join(map(str, fails))}")
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0377c808df36b60534784596e37088ba236acb
--- /dev/null
+++ b/train.py
@@ -0,0 +1,195 @@
+import logging
+import os
+from pathlib import Path
+from time import sleep, time
+
+import hydra
+import yaml
+from addict import Dict
+from comet_ml import ExistingExperiment, Experiment
+from omegaconf import OmegaConf
+
+from climategan.trainer import Trainer
+from climategan.utils import (
+ comet_kwargs,
+ copy_run_files,
+ env_to_path,
+ find_existing_training,
+ flatten_opts,
+ get_existing_comet_id,
+ get_git_branch,
+ get_git_revision_hash,
+ get_increased_path,
+ kill_job,
+ load_opts,
+ pprint,
+)
+
+logging.basicConfig()
+logging.getLogger().setLevel(logging.ERROR)
+
+hydra_config_path = Path(__file__).resolve().parent / "shared/trainer/config.yaml"
+
+
+# requires hydra-core==0.11.3 and omegaconf==1.4.1
+@hydra.main(config_path=hydra_config_path, strict=False)
+def main(opts):
+ """
+ Opts prevalence:
+ 1. Load file specified in args.default (or shared/trainer/defaults.yaml
+ if none is provided)
+ 2. Update with file specified in args.config (or no update if none is provided)
+ 3. Update with parsed command-line arguments
+
+ e.g.
+ `python train.py args.config=config/large-lr.yaml data.loaders.batch_size=10`
+ loads defaults, overrides with values in large-lr.yaml and sets batch_size to 10
+ """
+
+ # -----------------------------
+ # ----- Parse arguments -----
+ # -----------------------------
+
+ hydra_opts = Dict(OmegaConf.to_container(opts))
+ args = hydra_opts.pop("args", None)
+ auto_resumed = {}
+
+ config_path = args.config
+
+ if hydra_opts.train.resume:
+ out_ = str(env_to_path(hydra_opts.output_path))
+ config_path = Path(out_) / "opts.yaml"
+ if not config_path.exists():
+ config_path = None
+ print("WARNING: could not reuse the opts in {}".format(out_))
+
+ default = args.default or Path(__file__).parent / "shared/trainer/defaults.yaml"
+
+ # -----------------------
+ # ----- Load opts -----
+ # -----------------------
+
+ opts = load_opts(config_path, default=default, commandline_opts=hydra_opts)
+ if args.resume:
+ opts.train.resume = True
+
+ opts.jobID = os.environ.get("SLURM_JOBID")
+ opts.slurm_partition = os.environ.get("SLURM_JOB_PARTITION")
+ opts.output_path = str(env_to_path(opts.output_path))
+ print("Config output_path:", opts.output_path)
+
+ exp = comet_previous_id = None
+
+ # -------------------------------
+ # ----- Check output_path -----
+ # -------------------------------
+
+ # Auto-continue if same slurm job ID (=job was requeued)
+ if not opts.train.resume and opts.train.auto_resume:
+ print("\n\nTrying to auto-resume...")
+ existing_path = find_existing_training(opts)
+ if existing_path is not None and existing_path.exists():
+ auto_resumed["original output_path"] = str(opts.output_path)
+ auto_resumed["existing_path"] = str(existing_path)
+ opts.train.resume = True
+ opts.output_path = str(existing_path)
+
+ # Still not resuming: creating new output path
+ if not opts.train.resume:
+ opts.output_path = str(get_increased_path(opts.output_path))
+ Path(opts.output_path).mkdir(parents=True, exist_ok=True)
+
+ # Copy the opts's sbatch_file to output_path
+ copy_run_files(opts)
+ # store git hash
+ opts.git_hash = get_git_revision_hash()
+ opts.git_branch = get_git_branch()
+
+ if not args.no_comet:
+ # ----------------------------------
+ # ----- Set Comet Experiment -----
+ # ----------------------------------
+
+ if opts.train.resume:
+ # Is resuming: get existing comet exp id
+ assert Path(opts.output_path).exists(), "Output_path does not exist"
+
+ comet_previous_id = get_existing_comet_id(opts.output_path)
+ # Continue existing experiment
+ if comet_previous_id is None:
+                print("WARNING: could not retrieve previous comet id")
+ print(f"from {opts.output_path}")
+ else:
+ print("Continuing previous experiment", comet_previous_id)
+ auto_resumed["continuing exp id"] = comet_previous_id
+ exp = ExistingExperiment(
+ previous_experiment=comet_previous_id, **comet_kwargs
+ )
+ print("Comet Experiment resumed")
+
+ if exp is None:
+ # Create new experiment
+ print("Starting new experiment")
+ exp = Experiment(project_name="climategan", **comet_kwargs)
+ exp.log_asset_folder(
+ str(Path(__file__).parent / "climategan"),
+ recursive=True,
+ log_file_name=True,
+ )
+ exp.log_asset(str(Path(__file__)))
+
+ # Log note
+ if args.note:
+ exp.log_parameter("note", args.note)
+
+ # Merge and log tags
+ if args.comet_tags or opts.comet.tags:
+ tags = set([f"branch:{opts.git_branch}"])
+ if args.comet_tags:
+ tags.update(args.comet_tags)
+ if opts.comet.tags:
+ tags.update(opts.comet.tags)
+ opts.comet.tags = list(tags)
+ print("Logging to comet.ml with tags", opts.comet.tags)
+ exp.add_tags(opts.comet.tags)
+
+ # Log all opts
+ exp.log_parameters(flatten_opts(opts))
+ if auto_resumed:
+ exp.log_text("\n".join(f"{k:20}: {v}" for k, v in auto_resumed.items()))
+
+ # allow some time for comet to get its url
+ sleep(1)
+
+ # Save comet exp url
+ url_path = get_increased_path(Path(opts.output_path) / "comet_url.txt")
+ with open(url_path, "w") as f:
+ f.write(exp.url)
+
+ # Save config file
+ opts_path = get_increased_path(Path(opts.output_path) / "opts.yaml")
+ with (opts_path).open("w") as f:
+ yaml.safe_dump(opts.to_dict(), f)
+
+ pprint("Running model in", opts.output_path)
+
+ # -------------------
+ # ----- Train -----
+ # -------------------
+
+ trainer = Trainer(opts, comet_exp=exp, verbose=1)
+ trainer.logger.time.start_time = time()
+ trainer.setup()
+ trainer.train()
+
+ # -----------------------------
+ # ----- End of training -----
+ # -----------------------------
+
+ pprint("Done training")
+ kill_job(opts.jobID)
+
+
+if __name__ == "__main__":
+
+ main()
diff --git a/utils_scripts/compare_maskers.py b/utils_scripts/compare_maskers.py
new file mode 100644
index 0000000000000000000000000000000000000000..606fd06c653d748244623a7f353f8deb7865a935
--- /dev/null
+++ b/utils_scripts/compare_maskers.py
@@ -0,0 +1,344 @@
+import sys
+from argparse import ArgumentParser
+from pathlib import Path
+from comet_ml import Experiment
+
+import numpy as np
+import torch
+import yaml
+from PIL import Image
+from skimage.color import gray2rgb
+from skimage.io import imread
+from skimage.transform import resize
+from skimage.util import img_as_ubyte
+from tqdm import tqdm
+
+sys.path.append(str(Path(__file__).resolve().parent.parent))
+
+import climategan
+
+GROUND_MODEL = "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--ground"
+
+
+def uint8(array):
+ return array.astype(np.uint8)
+
+
+def crop_and_resize(image_path, label_path):
+ """
+ Resizes an image so that it keeps the aspect ratio and the smallest dimensions
+ is 640, then crops this resized image in its center so that the output is 640x640
+ without aspect ratio distortion
+
+ Args:
+ image_path (Path or str): Path to an image
+ label_path (Path or str): Path to the image's associated label
+
+ Returns:
+ tuple((np.ndarray, np.ndarray)): (new image, new label)
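+
+    Example (illustrative): a 720x1280 image is resized to 640x1137, then
+    center-cropped to 640x640, so there is no aspect-ratio distortion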
+ """
+
+ img = imread(image_path)
+ lab = imread(label_path)
+
+ # if img.shape[-1] == 4:
+ # img = uint8(rgba2rgb(img) * 255)
+
+ # TODO: remove (debug)
+ if img.shape[:2] != lab.shape[:2]:
+ print(
+ "\nWARNING: shape mismatch: im -> {}, lab -> {}".format(
+ image_path.name, label_path.name
+ )
+ )
+ # breakpoint()
+
+ # resize keeping aspect ratio: smallest dim is 640
+ h, w = img.shape[:2]
+ if h < w:
+ size = (640, int(640 * w / h))
+ else:
+ size = (int(640 * h / w), 640)
+
+ r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
+ r_img = uint8(r_img)
+
+ r_lab = resize(lab, size, preserve_range=True, anti_aliasing=False, order=0)
+ r_lab = uint8(r_lab)
+
+ # crop in the center
+ H, W = r_img.shape[:2]
+
+ top = (H - 640) // 2
+ left = (W - 640) // 2
+
+ rc_img = r_img[top : top + 640, left : left + 640, :]
+ rc_lab = (
+ r_lab[top : top + 640, left : left + 640, :]
+ if r_lab.ndim == 3
+ else r_lab[top : top + 640, left : left + 640]
+ )
+
+ return rc_img, rc_lab
+
+
+def load_ground(ground_output_path, ref_image_path):
+ gop = Path(ground_output_path)
+ rip = Path(ref_image_path)
+
+ ground_paths = list((gop / "eval-metrics" / "pred").glob(f"{rip.stem}.jpg")) + list(
+ (gop / "eval-metrics" / "pred").glob(f"{rip.stem}.png")
+ )
+ if len(ground_paths) == 0:
+ raise ValueError(
+ f"Could not find a ground match in {str(gop)} for image {str(rip)}"
+ )
+ elif len(ground_paths) > 1:
+ raise ValueError(
+ f"Found more than 1 ground match in {str(gop)} for image {str(rip)}:"
+ + f" {list(map(str, ground_paths))}"
+ )
+ ground_path = ground_paths[0]
+ _, ground = crop_and_resize(rip, ground_path)
+ ground = (ground > 0).astype(np.float32)
+ return torch.from_numpy(ground).unsqueeze(0).unsqueeze(0).cuda()
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument("-y", "--yaml", help="Path to a list of models")
+ parser.add_argument(
+ "--disable_loading",
+ action="store_true",
+ default=False,
+ help="Disable loading of existing inferences",
+ )
+ parser.add_argument(
+ "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
+ )
+ parser.add_argument(
+ "--tasks",
+ nargs="*",
+        help="Tasks to concatenate in the comparison grid (x, d, s, m, mx, p)",
+ default=["x", "d", "s", "m", "mx", "p"],
+ type=str,
+ )
+ args = parser.parse_args()
+
+ print("Received args:")
+ print(vars(args))
+
+ return args
+
+
+def load_images_and_labels(
+ path="/miniscratch/_groups/ccai/data/omnigan/masker-test-set",
+):
+ p = Path(path)
+ ims_path = p / "imgs"
+ lab_path = p / "labels"
+
+ ims = sorted(climategan.utils.find_images(ims_path), key=lambda x: x.name)
+ labs = sorted(
+ climategan.utils.find_images(lab_path),
+ key=lambda x: x.name.replace("_labeled.", "."),
+ )
+
+ xs = climategan.transforms.PrepareInference()(ims)
+ ys = climategan.transforms.PrepareInference(is_label=True)(labs)
+
+ return xs, ys, ims, labs
+
+
+def load_inferences(inf_path, im_paths):
+ try:
+ assert inf_path.exists()
+ assert sorted([i.stem for i in im_paths]) == sorted(
+ [i.stem for i in inf_path.glob("*.pt")]
+ )
+ return [torch.load(str(i)) for i in tqdm(list(inf_path.glob("*.pt")))]
+ except Exception as e:
+ print()
+ print(e)
+ print("Aborting Loading")
+ print()
+ return None
+
+
+def get_or_load_inferences(
+ m_path, device, xs, is_ground, im_paths, ground_model, try_load=True
+):
+ inf_path = Path(m_path) / "inferences"
+ if try_load:
+ print("Trying to load existing inferences:")
+ outputs = load_inferences(inf_path, im_paths)
+ if outputs is not None:
+ print("Successfully loaded existing inferences")
+ return outputs
+
+ trainer = climategan.trainer.Trainer.resume_from_path(
+ m_path if not is_ground else ground_model,
+ inference=True,
+ new_exp=None,
+ device=device,
+ )
+
+ inf_path.mkdir(exist_ok=True)
+ outputs = []
+ for i, x in enumerate(tqdm(xs)):
+ x = x.to(trainer.device)
+ if not is_ground:
+ out = trainer.G.decode(x=x)
+ else:
+ out = {"m": load_ground(GROUND_MODEL, im_paths[i])}
+ out["p"] = trainer.G.paint(out["m"] > 0.5, x)
+ out["x"] = x
+ inference = {k: v.cpu() for k, v in out.items()}
+ outputs.append(inference)
+ torch.save(inference, inf_path / f"{im_paths[i].stem}.pt")
+ print()
+
+ return outputs
+
+
+def numpify(outputs):
+ nps = []
+ print("Numpifying...")
+ for o in tqdm(outputs):
+ x = (o["x"][0].permute(1, 2, 0).numpy() + 1) / 2
+ m = o["m"]
+ m = (m[0, 0, :, :].numpy() > 0.5).astype(np.uint8)
+ p = (o["p"][0].permute(1, 2, 0).numpy() + 1) / 2
+ data = {"m": m, "p": p, "x": x}
+ if "s" in o:
+ s = climategan.data.decode_segmap_merged_labels(o["s"], "r", False) / 255.0
+ data["s"] = s[0].permute(1, 2, 0).numpy()
+ if "d" in o:
+ d = climategan.tutils.normalize_tensor(o["d"]).squeeze().numpy()
+ data["d"] = d
+ nps.append({k: img_as_ubyte(v) for k, v in data.items()})
+ return nps
+
+
+def concat_npy_for_model(data, tasks):
+ assert "m" in data
+ assert "x" in data
+ assert "p" in data
+
+ x = mask = depth = seg = painted = masked = None
+
+ x = data["x"]
+ painted = data["p"]
+ mask = (gray2rgb(data["m"]) * 255).astype(np.uint8)
+ painted = data["p"]
+ masked = (1 - gray2rgb(data["m"])) * x
+
+ concats = []
+
+ if "d" in data:
+ depth = img_as_ubyte(
+ gray2rgb(
+ resize(data["d"], data["x"].shape[:2], anti_aliasing=True, order=1)
+ )
+ )
+ else:
+ depth = np.ones_like(data["x"]) * 255
+
+ if "s" in data:
+ seg = img_as_ubyte(
+ resize(data["s"], data["x"].shape[:2], anti_aliasing=False, order=0)
+ )
+ else:
+ seg = np.ones_like(data["x"]) * 255
+
+ for t in tasks:
+ if t == "x":
+ concats.append(x)
+ if t == "m":
+ concats.append(mask)
+ elif t == "mx":
+ concats.append(masked)
+ elif t == "d":
+ concats.append(depth)
+ elif t == "s":
+ concats.append(seg)
+ elif t == "p":
+ concats.append(painted)
+
+ row = np.concatenate(concats, axis=1)
+
+ return row
+
+
+if __name__ == "__main__":
+ args = parse_args()
+
+ with open(args.yaml, "r") as f:
+ maskers = yaml.safe_load(f)
+ if "models" in maskers:
+ maskers = maskers["models"]
+
+ load = not args.disable_loading
+ tags = args.tags
+ tasks = args.tasks
+
+ ground_model = None
+ for m in maskers:
+        if "ground" not in m:
+ ground_model = m
+ break
+ if ground_model is None:
+ raise ValueError("Could not find a non-ground model to get a painter")
+
+ device = torch.device("cuda:0")
+ torch.set_grad_enabled(False)
+
+ xs, ys, im_paths, lab_paths = load_images_and_labels()
+
+ np_outs = {}
+ names = []
+
+ for m_path in maskers:
+
+ opt_path = Path(m_path) / "opts.yaml"
+ with opt_path.open("r") as f:
+ opt = yaml.safe_load(f)
+
+ name = (
+ ", ".join(
+ [
+ t
+ for t in sorted(opt["comet"]["tags"])
+ if "branch" not in t and "ablation" not in t and "trash" not in t
+ ]
+ )
+ if "--ground" not in m_path
+ else "ground"
+ )
+ names.append(name)
+
+ is_ground = name == "ground"
+
+ print("#" * 100)
+ print("\n>>> Processing", name)
+ print()
+
+ outputs = get_or_load_inferences(
+ m_path, device, xs, is_ground, im_paths, ground_model, load
+ )
+ nps = numpify(outputs)
+
+ np_outs[name] = nps
+
+ exp = Experiment(project_name="climategan-inferences", display_summary_level=0)
+ exp.log_parameter("names", names)
+ exp.add_tags(tags)
+
+ for i in tqdm(range(len(xs))):
+ all_models_for_image = []
+ for name in names:
+ xpmds = concat_npy_for_model(np_outs[name][i], tasks)
+ all_models_for_image.append(xpmds)
+ full_im = np.concatenate(all_models_for_image, axis=0)
+ pil_im = Image.fromarray(full_im)
+ exp.log_image(pil_im, name=im_paths[i].stem.replace(".", "_"), step=i)
diff --git a/utils_scripts/create_labeled.py b/utils_scripts/create_labeled.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bf0d02b74a67dd1cace6e0a4ffe778b59ac7f66
--- /dev/null
+++ b/utils_scripts/create_labeled.py
@@ -0,0 +1,25 @@
+from pathlib import Path
+from skimage.io import imread, imsave
+import numpy as np
+
+if __name__ == "__main__":
+ impath = Path("/Users/victor/Downloads/metrics-v2/imgs")
+ labpath = Path("/Users/victor/Downloads/metrics-v2/labels")
+ outpath = Path("/Users/victor/Downloads/metrics-v2/labeled")
+ outpath.mkdir(exist_ok=True, parents=True)
+ ims = sorted(
+ [d for d in impath.iterdir() if d.is_file() and not d.name.startswith(".")],
+ key=lambda x: x.stem,
+ )
+ labs = sorted(
+ [d for d in labpath.iterdir() if d.is_file() and not d.name.startswith(".")],
+ key=lambda x: x.stem.replace("_labeled", ""),
+ )
+
+ for k, (i, l) in enumerate(zip(ims, labs)):
+ print(f"{k + 1} / {len(ims)}", end="\r", flush=True)
+ assert i.stem == l.stem.replace("_labeled", "")
+ im = imread(i)[:, :, :3]
+ la = imread(l)
+ ld = (0.7 * im + 0.3 * la).astype(np.uint8)
+ imsave(outpath / i.name, ld)
diff --git a/utils_scripts/download_comet_images.py b/utils_scripts/download_comet_images.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6029f7097be1937b0d0f9a8faec8ea631db5136
--- /dev/null
+++ b/utils_scripts/download_comet_images.py
@@ -0,0 +1,311 @@
+import argparse
+import os
+from collections import Counter
+from pathlib import Path
+
+import comet_ml
+import yaml
+from addict import Dict
+from comet_ml import config
+
+
+def parse_tags(tags_str):
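+    # Tags prefixed with "!" or "~" are exclusion tags.
+    # Illustrative example: parse_tags("masker, v1, !trash")
+    #   -> all: {"masker", "v1", "!trash"}, keep: {"masker", "v1"}, remove: {"trash"}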
+ all_tags = set(t.strip() for t in tags_str.split(","))
+ keep_tags = set()
+ remove_tags = set()
+ for t in all_tags:
+ if "!" in t or "~" in t:
+ remove_tags.add(t[1:])
+ else:
+ keep_tags.add(t)
+ return all_tags, keep_tags, remove_tags
+
+
+def select_lambdas(vars):
+ """
+ Create a specific file with the painter's lambdas
+
+ Args:
+ vars (dict): output of locals()
+ """
+ opts = vars["opts"]
+ dev = vars["args"].dev
+ lambdas = opts.train.lambdas.G.p
+ if not dev:
+ with open("./painter_lambdas.yaml", "w") as f:
+ yaml.safe_dump(lambdas.to_dict(), f)
+
+
+def parse_value(v: str):
+ """
+ Parses a string into bool or list or int or float or returns it as is
+
+ Args:
+ v (str): value to parse
+
+ Returns:
+ any: parsed value
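+
+    Example (illustrative):
+        parse_value("[1, 2.5, true]")  # -> [1, 2.5, True]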
+ """
+ if v.lower() == "false":
+ return False
+ if v.lower() == "true":
+ return True
+ if v.startswith("[") and v.endswith("]"):
+ return [
+ parse_value(sub_v)
+ for sub_v in v.replace("[", "").replace("]", "").split(", ")
+ ]
+ if "." in v:
+ try:
+ vv = float(v)
+ return vv
+ except ValueError:
+ return v
+ else:
+ try:
+ vv = int(v)
+ return vv
+ except ValueError:
+ return v
+
+
+def parse_opts(summary):
+ """
+ Parses a flatten_opts summary into an addict.Dict
+
+ Args:
+ summary (list(dict)): List of dicts from exp.get_parameters_summary()
+
+ Returns:
+ addict.Dict: parsed exp params
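+
+    Example (illustrative):
+        parse_opts([{"name": "gen.opt.lr", "valueCurrent": "0.0001"}])
+        # -> Dict({"gen": {"opt": {"lr": 0.0001}}})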
+ """
+ opts = Dict()
+ for item in summary:
+ k, v = item["name"], parse_value(item["valueCurrent"])
+ if "." in k:
+ d = opts
+ for subkey in k.split(".")[:-1]:
+ d = d[subkey]
+ d[k.split(".")[-1]] = v
+ else:
+ opts[k] = v
+ return opts
+
+
+def has_right_tags(exp: comet_ml.Experiment, keep: set, remove: set) -> bool:
+ """
+ All the "keep" tags should be in the experiment's tags
+ None of the "remove" tags should be in the experiment's tags.
+
+ Args:
+ exp (comet_ml.Experiment): experiment to select (or not)
+ keep (set): tags the exp should have
+ remove (set): tags the exp cannot have
+
+ Returns:
+ bool: should this exp be selected
+ """
+ tags = set(exp.get_tags())
+ has_all_keep = keep.intersection(tags) == keep
+ has_any_remove = remove.intersection(tags)
+ return has_all_keep and not has_any_remove
+
+
+if __name__ == "__main__":
+ # ------------------------
+ # ----- Parse args -----
+ # ------------------------
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-e", "--exp_id", type=str, default="")
+ parser.add_argument(
+ "-d",
+ "--download_dir",
+ type=str,
+ default=None,
+ help="Where to download the images",
+ )
+ parser.add_argument(
+ "-s", "--step", default="last", type=str, help="`last`, `all` or `int`"
+ )
+ parser.add_argument(
+ "-b",
+ "--base_dir",
+ default="./",
+ type=str,
+ help="if download_dir is not specified, download into base_dir/exp_id[:8]/",
+ )
+ parser.add_argument(
+ "-t",
+ "--tags",
+ default="",
+ type=str,
+        help="download all images of all experiments matching a set of tags",
+ )
+ parser.add_argument(
+ "-i",
+ "--id_length",
+ default=8,
+ type=int,
+ help="Length of the experiment's ID substring to make dirs: exp.id[:id_length]",
+ )
+ parser.add_argument(
+ "--dev",
+ default=False,
+ action="store_true",
+ help="dry run: no mkdir, no download",
+ )
+ parser.add_argument(
+ "-p",
+ "--post_processings",
+ default="",
+ type=str,
+ help="comma separated string list of post processing functions to apply",
+ )
+ parser.add_argument(
+ "-r",
+ "--running",
+ default=False,
+ action="store_true",
+ help="only select running exps",
+ )
+ args = parser.parse_args()
+ print(args)
+
+ # -------------------------------------
+ # ----- Create post processings -----
+ # -------------------------------------
+
+ POST_PROCESSINGS = {"select_lambdas": select_lambdas}
+ post_processes = list(
+ filter(
+ lambda p: p is not None,
+ [POST_PROCESSINGS.get(k.strip()) for k in args.post_processings.split(",")],
+ )
+ )
+
+ # ------------------------------------------------------
+ # ----- Create Download Dir from download_dir or -----
+ # ----- base_dir/exp_id[:args.id_length] -----
+ # ------------------------------------------------------
+
+ download_dir = Path(args.download_dir or Path(args.base_dir)).resolve()
+ if not args.dev:
+ download_dir.mkdir(parents=True, exist_ok=True)
+
+ # ------------------------
+ # ----- Check step -----
+ # ------------------------
+
+ step = None
+ try:
+ step = int(args.step)
+ except ValueError:
+ step = args.step
+ assert step in {"last", "all"}
+
+ api = comet_ml.api.API()
+
+ # ---------------------------------------
+ # ----- Select exps based on tags -----
+ # ---------------------------------------
+ if not args.tags:
+ assert args.exp_id
+ exps = [api.get_experiment_by_id(args.exp_id)]
+ else:
+ all_tags, keep_tags, remove_tags = parse_tags(args.tags)
+ download_dir = download_dir / "&".join(sorted(all_tags))
+
+ print("Selecting experiments with tags", all_tags)
+ conf = dict(config.get_config())
+ exps = api.get_experiments(
+ workspace=conf.get("comet.workspace"),
+ project_name=conf.get("comet.project_name") or "climategan",
+ )
+ exps = list(filter(lambda e: has_right_tags(e, keep_tags, remove_tags), exps))
+ if args.running:
+ exps = [e for e in exps if e.alive]
+
+ # -------------------------
+ # ----- Print setup -----
+ # -------------------------
+
+ print(
+ "Processing {} experiments in {} with post processes {}".format(
+ len(exps), str(download_dir), post_processes
+ )
+ )
+ assert all(
+ [v == 1 for v in Counter([e.id[: args.id_length] for e in exps]).values()]
+ ), "Experiment ID conflict, use a larger --id_length"
+
+ for e, exp in enumerate(exps):
+ # ----------------------------------------------
+ # ----- Setup Current Download Directory -----
+ # ----------------------------------------------
+ cropped_id = exp.id[: args.id_length]
+ ddir = (download_dir / cropped_id).resolve()
+ if not args.dev:
+ ddir.mkdir(parents=True, exist_ok=True)
+
+ # ------------------------------
+ # ----- Fetch image list -----
+ # ------------------------------
+ ims = [asset for asset in exp.get_asset_list() if asset["image"] is True]
+
+ # -----------------------------------
+ # ----- Filter images by step -----
+ # -----------------------------------
+
+ if step == "last":
+ curr_step = max(i["step"] or -1 for i in ims)
+ if curr_step == -1:
+ curr_step = None
+ else:
+ curr_step = step
+
+ ims = [i for i in ims if (i["step"] == curr_step) or (step == "all")]
+
+ ddir = ddir / str(curr_step)
+ if not args.dev:
+ ddir.mkdir(parents=True, exist_ok=True)
+
+ # ----------------------------------------------
+ # ----- Store experiment's link and opts -----
+ # ----------------------------------------------
+ summary = exp.get_parameters_summary()
+ opts = parse_opts(summary)
+ if not args.dev:
+ with open(ddir / "url.txt", "w") as f:
+ f.write(exp.url)
+ with open(ddir / "opts.yaml", "w") as f:
+ yaml.safe_dump(opts.to_dict(), f)
+
+ # ------------------------------------------
+ # ----- Download png files with curl -----
+ # ------------------------------------------
+ print(
+ " >>> Downloading exp {}'s image at step `{}` into {}".format(
+ cropped_id, args.step, str(ddir)
+ )
+ )
+
+ for i, im in enumerate(ims):
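+ # skip assets that look like they have already been downloaded (path check is relative to the cwd)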
+ if not Path(im["fileName"] + "_{}.png".format(curr_step)).exists():
+ print(
+ "\nDownloading exp {}/{} image {}/{}: {} in {}".format(
+ e + 1, len(exps), i + 1, len(ims), im["fileName"], ddir
+ )
+ )
+ if not args.dev:
+ assert len(im["curlDownload"].split(" > ")) == 2
+ curl_command = im["curlDownload"].split(" > ")[0]
+ file_stem = Path(im["curlDownload"].split(" > ")[1]).stem
+
+ file_path = (
+ f'"{str(ddir / file_stem)}_{cropped_id}_{curr_step}.png"'
+ )
+
+ signal = os.system(f"{curl_command} > {file_path}")
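+ # post-processing hooks receive the current local variables (ddir, exp, im, ...)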
+ for p in post_processes:
+ p(locals())
diff --git a/utils_scripts/download_labelbox.py b/utils_scripts/download_labelbox.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbfb27f59fcd893ec183a0756ec8d4e68fe1ef13
--- /dev/null
+++ b/utils_scripts/download_labelbox.py
@@ -0,0 +1,41 @@
+import json
+import os
+from pathlib import Path
+
+if __name__ == "__main__":
+ # labelbox json export path
+ path = "/Users/victor/Downloads/export-2021-02-27T17_15_30.291Z.json"
+ # where to write the downloaded images
+ out = Path("/Users/victor/Downloads/labelbox_test_flood-v2")
+ # create out dir
+ out.mkdir(exist_ok=True, parents=True)
+
+ # load export data
+ with open(path, "r") as f:
+ data = json.load(f)
+
+ for i, d in enumerate(data):
+ # find all polygons
+ objects = d["Label"]["objects"]
+ # retrieve original image name
+ name = d["External ID"]
+ stem = Path(name).stem
+ # output dir for current image
+ m_out = out / stem[:30]
+ m_out.mkdir(exist_ok=True, parents=True)
+
+ # save 1 png per polygon
+ for o, obj in enumerate(objects):
+ print(f"{i}/{len(data)} : {o}/{len(objects)}")
+
+ # create verbose label -> "cannotflood", "mustflood"
+ label = obj["value"].replace("_", "")
+ # unique polygon mask filename
+ m_path = m_out / f"{stem}_{label}_{o}.png"
+ # download address for curl
+ uri = obj["instanceURI"]
+ # command to download the image
+ command = f'curl {uri} > "{str(m_path)}"'
+ # execute command
+ os.system(command)
+ print("#" * 20)
diff --git a/utils_scripts/make-labelbox.sh b/utils_scripts/make-labelbox.sh
new file mode 100644
index 0000000000000000000000000000000000000000..d649238546b996ae1ad00f7753c04aeebc7aaa97
--- /dev/null
+++ b/utils_scripts/make-labelbox.sh
@@ -0,0 +1,9 @@
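+# Local pipeline (hard-coded paths): download Labelbox masks, merge them per image,
+# refresh the metrics labels, then build the labeled overlays.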
+echo "Dowloading Script" && python download_labelbox.py
+
+echo "Merging Script" && python merge_labelbox_masks.py
+
+echo "Cleaning labeled"
+rm /Users/victor/Downloads/metrics-v2/labels/*
+cp /Users/victor/Downloads/labelbox_test_flood-v2/__labeled/* /Users/victor/Downloads/metrics-v2/labels
+
+echo "Create labeled images Script" && python create_labeled.py
\ No newline at end of file
diff --git a/utils_scripts/make_640_masker_validation_set.py b/utils_scripts/make_640_masker_validation_set.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ca26f553387d65b9bc9fd47074f7e54d10c48e4
--- /dev/null
+++ b/utils_scripts/make_640_masker_validation_set.py
@@ -0,0 +1,198 @@
+import sys
+from pathlib import Path
+from skimage.io import imread, imsave
+from skimage.transform import resize
+from skimage.color import rgba2rgb
+from argparse import ArgumentParser
+import numpy as np
+
+IMG_EXTENSIONS = set(
+ [".jpg", ".JPG", ".jpeg", ".JPEG", ".png", ".PNG", ".ppm", ".PPM", ".bmp", ".BMP"]
+)
+
+
+def is_image_file(filename):
+ """Check that a file's name points to a known image format
+ """
+ if isinstance(filename, Path):
+ return filename.suffix in IMG_EXTENSIONS
+
+ return Path(filename).suffix in IMG_EXTENSIONS
+
+
+def find_images(path, recursive=False):
+ """
+ Get a list of all images contained in a directory:
+
+ - path.glob("*") if not recursive
+ - path.glob("**/*") if recursive
+ """
+ p = Path(path)
+ assert p.exists()
+ assert p.is_dir()
+ pattern = "*"
+ if recursive:
+ pattern += "*/*"
+
+ return [i for i in p.glob(pattern) if i.is_file() and is_image_file(i)]
+
+
+def uint8(array):
+ return array.astype(np.uint8)
+
+
+def crop_and_resize(image_path, label_path):
+ """
+ Resizes an image, keeping its aspect ratio, so that its smallest dimension
+ is 640, then crops the resized image at its center so that the output is 640x640
+ without aspect-ratio distortion
+
+ Args:
+ image_path (Path or str): Path to an image
+ label_path (Path or str): Path to the image's associated label
+
+ Returns:
+ tuple((np.ndarray, np.ndarray)): (new image, new label)
+ """
+ dolab = label_path is not None
+
+ img = imread(image_path)
+ if dolab:
+ lab = imread(label_path)
+
+ if img.shape[-1] == 4:
+ img = uint8(rgba2rgb(img) * 255)
+
+ if dolab and img.shape != lab.shape:
+ print("\nWARNING: shape mismatch. Entering breakpoint to investigate:")
+ breakpoint()
+
+ # resize keeping aspect ratio: smallest dim is 640
+ h, w = img.shape[:2]
+ if h < w:
+ size = (640, int(640 * w / h))
+ else:
+ size = (int(640 * h / w), 640)
+
+ r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
+ r_img = uint8(r_img)
+
+ if dolab:
+ # nearest neighbor for labels
+ r_lab = resize(lab, size, preserve_range=True, anti_aliasing=False, order=0)
+ r_lab = uint8(r_lab)
+
+ # crop in the center
+ H, W = r_img.shape[:2]
+
+ top = (H - 640) // 2
+ left = (W - 640) // 2
+
+ rc_img = r_img[top : top + 640, left : left + 640, :]
+ if dolab:
+ rc_lab = r_lab[top : top + 640, left : left + 640, :]
+ else:
+ rc_lab = None
+
+ return rc_img, rc_lab
+
+
+def label(img, label, alpha=0.4):
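+ # alpha-blend the label mask over the image (for visual inspection)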
+ return uint8(alpha * label + (1 - alpha) * img)
+
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument(
+ "-i", "--input_dir", type=str, help="Directory to recursively read images from"
+ )
+ parser.add_argument(
+ "-o",
+ "--output_dir",
+ type=str,
+ help="Where to writ the result of the script,"
+ + " keeping the input dir's structure",
+ )
+ parser.add_argument(
+ "--no_labels",
+ action="store_true",
+ help="Only process images, don't look for labels",
+ )
+ parser.add_argument(
+ "--store_labeled",
+ action="store_true",
+ help="Store a superposition of the label and the image in out/labeled/",
+ )
+ args = parser.parse_args()
+
+ dolab = not args.no_labels
+ dolabeled = args.store_labeled
+
+ input_base = Path(args.input_dir).expanduser().resolve()
+ output_base = Path(args.output_dir).expanduser().resolve()
+
+ input_images = input_base / "imgs"
+ output_images = output_base / "imgs"
+
+ if dolab:
+ input_labels = input_base / "labels"
+ output_labels = output_base / "labels"
+ if dolabeled:
+ output_labeled = output_base / "labeled"
+
+ print("Input images:", str(input_images))
+ print("Output images:", str(output_images))
+ if dolab:
+ print("Input labels:", str(input_labels))
+ print("Output labels:", str(output_labels))
+ if dolabeled:
+ print("Output labeled:", str(output_labeled))
+ else:
+ print("NO LABEL PROCESSING (args.no_labels is specified)")
+ print()
+
+ assert input_images.exists()
+ if dolab:
+ assert input_labels.exists()
+
+ if output_base.exists():
+ if (
+ "n"
+ in input(
+ "WARNING: output dir already exists."
+ + " Overwrite its content? (y/n, default: y)"
+ ).lower()
+ ):
+ sys.exit()
+
+ output_images.mkdir(parents=True, exist_ok=True)
+ if dolab:
+ output_labels.mkdir(parents=True, exist_ok=True)
+ if dolabeled:
+ output_labeled.mkdir(parents=True, exist_ok=True)
+
+ images_paths = list(
+ map(Path, sorted((map(str, find_images(input_images, recursive=True)))))
+ )
+ if dolab:
+ labels_paths = list(
+ map(Path, sorted((map(str, find_images(input_labels, recursive=True)))))
+ )
+ else:
+ labels_paths = [None] * len(images_paths)
+
+ for i, (image_path, label_path) in enumerate(zip(images_paths, labels_paths)):
+ print(
+ f"Processing {i + 1 :3} / {len(images_paths)} : {image_path.name}",
+ end="\r",
+ flush=True,
+ )
+ processed_image, processed_label = crop_and_resize(image_path, label_path)
+ imsave(output_images / f"{image_path.stem}.png", processed_image)
+ if dolab:
+ imsave(output_labels / f"{label_path.stem}.png", processed_label)
+ if dolabeled:
+ labeled = label(processed_image, processed_label)
+ imsave(output_labeled / f"{image_path.stem}.png", labeled)
+
+ print("\nDone.")
diff --git a/utils_scripts/merge_labelbox_masks.py b/utils_scripts/merge_labelbox_masks.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a2df93996e94d89c81054f4f4a53766c704d95
--- /dev/null
+++ b/utils_scripts/merge_labelbox_masks.py
@@ -0,0 +1,41 @@
+from pathlib import Path
+
+import numpy as np
+from skimage.io import imread, imsave
+from shutil import copyfile
+
+if __name__ == "__main__":
+ # output of download_labelbox.py
+ base_dir = Path("/Users/victor/Downloads/labelbox_test_flood-v2")
+ labeled_dir = base_dir / "__labeled"
+ assert base_dir.exists()
+ labeled_dir.mkdir(exist_ok=True)
+
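+ # one sub-directory per source image (one mask file per annotated polygon inside)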
+ sub_dirs = [
+ d
+ for d in base_dir.expanduser().resolve().iterdir()
+ if d.is_dir() and not d.name.startswith(".") and d.name != "__labeled"
+ ]
+
+ for k, sd in enumerate(sub_dirs):
+ print(k + 1, "/", len(sub_dirs), sd.name)
+
+ # must-flood binary mask
+ must = np.stack([imread(i)[:, :, :3] for i in sd.glob("*must*.png")]).sum(0) > 0
+ # cannot-flood binary mask
+ cannot = (
+ np.stack([imread(i)[:, :, :3] for i in sd.glob("*cannot*.png")]).sum(0) > 0
+ )
+ # must-flood goes in the last channel (blue, since arrays are saved as RGB)
+ must = (must * [0, 0, 255]).astype(np.uint8)
+ # cannot-flood goes in the first channel (red)
+ cannot = (cannot * [255, 0, 0]).astype(np.uint8)
+ # merged labels
+ label = must + cannot
+ # check no overlap
+ assert sorted(np.unique(label)) == [0, 255]
+ # create filename
+ stem = "_".join(list(sd.glob("*must*.png"))[0].stem.split("_")[:-2])
+ # save label
+ imsave(sd / f"{stem}_labeled.png", label)
+ copyfile(sd / f"{stem}_labeled.png", labeled_dir / f"{stem}_labeled.png")
diff --git a/utils_scripts/upload_images_to_comet.py b/utils_scripts/upload_images_to_comet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8b86f9f5733291cb410721d7bbd5fa780c6cc3f
--- /dev/null
+++ b/utils_scripts/upload_images_to_comet.py
@@ -0,0 +1,26 @@
+import comet_ml # noqa: F401
+from pathlib import Path
+import sys
+from argparse import ArgumentParser
+
+sys.path.append(str(Path(__file__).resolve().parent.parent))
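+# make the climategan package importable when running this script from utils_scripts/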
+
+from climategan.utils import upload_images_to_exp
+
+if __name__ == "__main__":
+ parser = ArgumentParser()
+ parser.add_argument("-i", "--images_path", type=str, default=".")
+ parser.add_argument("-p", "--project_name", type=str, default="climategan-eval")
+ parser.add_argument("-s", "--sleep", type=int, default=0.1)
+ parser.add_argument("-v", "--verbose", type=int, default=1)
+ args = parser.parse_args()
+
+ exp = upload_images_to_exp(
+ Path(args.images_path).resolve(),
+ exp=None,
+ project_name=args.project_name,
+ sleep=args.sleep,
+ verbose=args.verbose,
+ )
+
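+ # upload_images_to_exp returns the Comet Experiment it logged to; end it to flush pending uploads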
+ exp.end()