import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-b",
        "--batch_size",
        type=int,
        default=4,
        help="Batch size to process input images to events. Defaults to 4",
    )
    parser.add_argument(
        "-i",
        "--images_paths",
        type=str,
        required=True,
        help="Path to a directory with image files",
    )
    parser.add_argument(
        "-o",
        "--output_path",
        type=str,
        default=None,
        help="Path to a directory where events should be written. "
        + "Will NOT write anything to disk if this flag is not used.",
    )
    parser.add_argument(
        "-s",
        "--save_input",
        action="store_true",
        default=False,
        help="Binary flag to include the model's input image (after crop and"
        + " resize) in the images written or uploaded (depending on saving options).",
    )
    parser.add_argument(
        "-r",
        "--resume_path",
        type=str,
        default=None,
        help="Path to a directory containing the trainer to resume."
        + " In particular it must contain `opts.yaml` and `checkpoints/`."
        + " Typically this points to a Masker, which holds the path to a"
        + " Painter in its opts",
    )
    parser.add_argument(
        "--no_time",
        action="store_true",
        default=False,
        help="Binary flag to prevent the timing of operations.",
    )
    parser.add_argument(
        "-f",
        "--flood_mask_binarization",
        type=float,
        default=0.5,
        help="Value to use to binarize masks (mask > value). "
        + "Set to -1 to use soft masks (not binarized). Defaults to 0.5.",
    )
    parser.add_argument(
        "-t",
        "--target_size",
        type=int,
        default=640,
        help="Output image size (when not using `keep_ratio_128`): images are resized"
        + " such that their smallest side is `target_size` then cropped in the middle"
        + " of the largest side such that the resulting input image (and output images)"
        + " is `target_size x target_size`. **Must** be a multiple of"
        + " 2^7=128 (up/downscaling inside the models). Defaults to 640.",
    )
    parser.add_argument(
        "--half",
        action="store_true",
        default=False,
        help="Binary flag to use half precision (float16). Defaults to False.",
    )
    parser.add_argument(
        "-n",
        "--n_images",
        default=-1,
        type=int,
        help="Limit the number of images processed (if you have 100 images in "
        + "a directory but n is 10 then only the first 10 images will be loaded"
        + " for processing)",
    )
    parser.add_argument(
        "--no_conf",
        action="store_true",
        default=False,
        help="Disable writing the apply_events hash and command in the output folder",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        default=False,
        help="Do not check for an existing outdir, i.e. force overwriting"
        + " potentially existing files in the output path",
    )
    parser.add_argument(
        "--no_cloudy",
        action="store_true",
        default=False,
        help="Prevent the use of the cloudy intermediate"
        + " image to create the flood image. Rendering will"
        + " be more colorful but may seem less realistic",
    )
    parser.add_argument(
        "--keep_ratio_128",
        action="store_true",
        default=False,
        help="When loading the input images, resize and crop them so that their "
        + "dimensions match the closest multiples"
        + " of 128. Will force a batch size of 1 since images"
        + " now have different dimensions. "
        + "Use --max_im_width to cap the resulting dimensions.",
    )
    parser.add_argument(
        "--fuse",
        action="store_true",
        default=False,
        help="Use batch norm fusion to speed up inference",
    )
    parser.add_argument(
        "--save_masks",
        action="store_true",
        default=False,
        help="Save output masks along with events",
    )
    parser.add_argument(
        "-m",
        "--max_im_width",
        type=int,
        default=-1,
        help="When using --keep_ratio_128, some images may still be too large. Use "
        + "--max_im_width to cap the resized image's width. Defaults to -1 (no cap).",
    )
    parser.add_argument(
        "--upload",
        action="store_true",
        help="Upload to comet.ml in a project called `climategan-apply`",
    )
    parser.add_argument(
        "--zip_outdir",
        "-z",
        action="store_true",
        help="Zip the output directory as '{outdir.parent}/{outdir.name}.zip'",
    )
    return parser.parse_args()
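
# Example invocation (hypothetical paths, adjust to your setup):
#   python apply_events.py -b 8 -i /path/to/images -o outputs/_auto_ \
#       -r /path/to/masker --half --zip_outdir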


args = parse_args()


print("\n• Imports\n")
import time

import_time = time.time()
import sys
import shutil
from collections import OrderedDict
from pathlib import Path

import comet_ml
import torch
import numpy as np
import skimage.io as io
from skimage.color import rgba2rgb
from skimage.transform import resize
from tqdm import tqdm

from climategan.trainer import Trainer
from climategan.bn_fusion import bn_fuse
from climategan.tutils import print_num_parameters
from climategan.utils import Timer, find_images, get_git_revision_hash, to_128, resolve

import_time = time.time() - import_time


def to_m1_p1(img, i):
    """
    Rescales a [0, 1] image to [-1, +1]

    Args:
        img (np.array): float32 numpy array of an image in [0, 1]
        i (int): Index of the image being rescaled

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image {i}: ({img.min()}, {img.max()})")
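
# Quick sanity check of to_m1_p1 (illustrative only):
#   >>> to_m1_p1(np.array([0.0, 0.5, 1.0], dtype=np.float32), 0)
#   array([-1.,  0.,  1.], dtype=float32)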


def uint8(array):
    """
    Converts an array to np.uint8; only changes the dtype, does not rescale values

    Args:
        array (np.array): array to modify

    Returns:
        np.array(np.uint8): converted array
    """
    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and its smallest dimension
    is `to`, then crops this resized image in its center so that the output is
    `to x to`, without aspect ratio distortion

    Args:
        img (np.array): np.uint8 255 image
        to (int, optional): target size of the output's sides. Defaults to 640.

    Returns:
        np.array: [0, 1] np.float32 image
    """
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    H, W = r_img.shape[:2]

    top = (H - to) // 2
    left = (W - to) // 2

    rc_img = r_img[top : top + to, left : left + to, :]

    return rc_img / 255.0
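
# Worked example: a 480x800 (h x w) image with to=640 is resized to 640x1066
# (smallest side becomes 640) then center-cropped to 640x640
# (left offset (1066 - 640) // 2 = 213).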


def print_time(text, time_series, purge=-1):
    """
    Prints a time series' mean and std with a label

    Args:
        text (str): label of the time series
        time_series (list): list of timings
        purge (int, optional): ignore first n values of the time series. Defaults to -1.
    """
    if not time_series:
        return

    if purge > 0 and len(time_series) > purge:
        time_series = time_series[purge:]

    m = np.mean(time_series)
    s = np.std(time_series)

    print(
        f"{text.capitalize() + ' ':.<26} {m:.5f}"
        + (f" +/- {s:.5f}" if len(time_series) > 1 else "")
    )
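
# Illustration, following the format string above:
#   print_time("encode", [0.5, 0.6])
# prints "Encode ................... 0.55000 +/- 0.05000"
# (label dot-padded to 26 characters).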


def print_store(store, purge=-1):
    """
    Pretty-prints a store of time series

    Args:
        store (dict): maps string keys to lists of times
        purge (int, optional): ignore first n values of each time series. Defaults to -1.
    """
    singles = OrderedDict({k: v for k, v in store.items() if len(v) == 1})
    multiples = OrderedDict({k: v for k, v in store.items() if len(v) > 1})
    empties = {k: v for k, v in store.items() if len(v) == 0}

    if empties:
        print("Ignoring empty stores:", ", ".join(empties.keys()))
        print()

    for k in singles:
        print_time(k, singles[k], purge)

    print()
    print("Unit: s/batch")
    for k in multiples:
        print_time(k, multiples[k], purge)
    print()


def write_apply_config(out):
    """
    Saves the command used to run `apply_events.py` and the current git hash
    in text files, for future reference
    """
    cwd = Path.cwd().expanduser().resolve()
    command = f"cd {str(cwd)}\n"
    command += " ".join(sys.argv)
    git_hash = get_git_revision_hash()
    with (out / "command.txt").open("w") as f:
        f.write(command)
    with (out / "hash.txt").open("w") as f:
        f.write(git_hash)
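
# command.txt then contains something like (hypothetical paths):
#   cd /path/to/ClimateGAN
#   python apply_events.py -i images/ -o outputs/_auto_ -r models/masker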


def get_outdir_name(half, keep_ratio, max_im_width, target_size, bin_value, cloudy):
    """
    Creates the output directory's name based on user-provided arguments
    """
    name_items = []
    if half:
        name_items.append("half")
    if keep_ratio:
        name_items.append("AR")
    if max_im_width and keep_ratio:
        name_items.append(f"{max_im_width}")
    if target_size and not keep_ratio:
        name_items.append("S")
        name_items.append(f"{target_size}")
    if bin_value != 0.5:
        name_items.append(f"bin{bin_value}")
    if not cloudy:
        name_items.append("no_cloudy")

    return "-".join(name_items)


def make_outdir(
    outdir, overwrite, half, keep_ratio, max_im_width, target_size, bin_value, cloudy
):
    """
    Creates the output directory if it does not exist. If it does exist,
    prompts the user for confirmation (unless `overwrite` is True).
    If the output directory's name is "_auto_" then it is created as:
    outdir.parent / get_outdir_name(...)
    """
    if outdir.name == "_auto_":
        outdir = outdir.parent / get_outdir_name(
            half, keep_ratio, max_im_width, target_size, bin_value, cloudy
        )
    if outdir.exists() and not overwrite:
        print(
            f"\nWARNING: outdir ({str(outdir)}) already exists."
            + " Files with existing names will be overwritten"
        )
        if "n" in input(">>> Continue anyway? [y / n] (default: y): "):
            print("Interrupting execution from user input.")
            sys.exit()
        print()
    outdir.mkdir(exist_ok=True, parents=True)
    return outdir
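
# E.g. passing `-o outputs/_auto_` with --half --keep_ratio_128 -m 1280 resolves
# the output directory to outputs/half-AR-1280 (hypothetical paths).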


def get_time_stores(import_time):
    """
    Initializes the timing stores, seeding "imports" with the measured import time
    """
    return OrderedDict(
        {
            "imports": [import_time],
            "setup": [],
            "data pre-processing": [],
            "encode": [],
            "mask": [],
            "flood": [],
            "depth": [],
            "segmentation": [],
            "smog": [],
            "wildfire": [],
            "all events": [],
            "numpy": [],
            "inference on all images": [],
            "write": [],
        }
    )


if __name__ == "__main__":

    print(
        "• Using args\n\n"
        + "\n".join(["{:25}: {}".format(k, v) for k, v in vars(args).items()]),
    )

    batch_size = args.batch_size
    bin_value = args.flood_mask_binarization
    cloudy = not args.no_cloudy
    fuse = args.fuse
    half = args.half
    save_masks = args.save_masks
    images_paths = resolve(args.images_paths)
    keep_ratio = args.keep_ratio_128
    max_im_width = args.max_im_width
    n_images = args.n_images
    outdir = resolve(args.output_path) if args.output_path is not None else None
    resume_path = args.resume_path
    target_size = args.target_size
    time_inference = not args.no_time
    upload = args.upload
    zip_outdir = args.zip_outdir

    if keep_ratio:
        if target_size != 640:
            print(
                "\nWARNING: using --keep_ratio_128 overrides target_size,"
                + " which is ignored."
            )
        if batch_size != 1:
            print("\nWARNING: batch_size is overridden to 1 when using keep_ratio_128")
            batch_size = 1
        if max_im_width > 0 and max_im_width % 128 != 0:
            new_im_width = int(max_im_width / 128) * 128
            print("\nWARNING: max_im_width should be negative (no cap) or a multiple of 128.")
            print(
                "         Was {} but is now overridden to {}".format(
                    max_im_width, new_im_width
                )
            )
            max_im_width = new_im_width
    else:
        if target_size % 128 != 0:
            print(f"\nWARNING: target size {target_size} is not a multiple of 128.")
            target_size = target_size - (target_size % 128)
            print(f"Setting target_size to {target_size}.")

    if outdir is not None:
        outdir = make_outdir(
            outdir,
            args.overwrite,
            half,
            keep_ratio,
            max_im_width,
            target_size,
            bin_value,
            cloudy,
        )

    stores = get_time_stores(import_time)

    with Timer(store=stores.get("setup", []), ignore=time_inference):
        print("\n• Initializing trainer\n")
        torch.set_grad_enabled(False)
        trainer = Trainer.resume_from_path(
            resume_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        print()
        print_num_parameters(trainer, True)
        if fuse:
            trainer.G = bn_fuse(trainer.G)
        if half:
            trainer.G.half()

    print("\n• Reading & Pre-processing Data\n")

    data_paths = find_images(images_paths)
    base_data_paths = data_paths

    # select the first n_images, repeating the list if n_images > len(data_paths)
    if 0 < n_images < len(data_paths):
        data_paths = data_paths[:n_images]
    elif n_images > len(data_paths):
        repeats = n_images // len(data_paths) + 1
        data_paths = base_data_paths * repeats
        data_paths = data_paths[:n_images]
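
    # e.g. 3 images on disk with -n 7 => data_paths is [p0, p1, p2, p0, p1, p2, p0]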

    with Timer(store=stores.get("data pre-processing", []), ignore=time_inference):
        # read images, converting RGBA to RGB if need be
        data = [io.imread(str(d)) for d in data_paths]
        data = [im if im.shape[-1] == 3 else uint8(rgba2rgb(im) * 255) for im in data]

        if keep_ratio:
            # resize each image so its dimensions are the closest multiples of 128
            new_sizes = [to_128(d, max_im_width) for d in data]
            data = [resize(d, ns, anti_aliasing=True) for d, ns in zip(data, new_sizes)]
        else:
            # resize then center-crop each image to target_size x target_size
            data = [resize_and_crop(d, target_size) for d in data]
            new_sizes = [(target_size, target_size) for _ in data]

        # rescale from [0, 1] to [-1, 1]
        data = [to_m1_p1(d, i) for i, d in enumerate(data)]

    # ceil(len(data) / batch_size) batches
    n_batchs = len(data) // batch_size
    if len(data) % batch_size != 0:
        n_batchs += 1

    print("Found", len(base_data_paths), "images. Inferring on", len(data), "images.")

    print(f"\n• Using device {str(trainer.device)}\n")

    all_events = []

    with Timer(store=stores.get("inference on all images", []), ignore=time_inference):
        for b in tqdm(range(n_batchs), desc="Inferring events", unit="batch"):

            images = data[b * batch_size : (b + 1) * batch_size]
            if not images:
                continue

            # concatenate images in a batch: batch_size x height x width x 3
            images = np.stack(images)

            # infer events as a dict {event_name: batched array of images}
            events = trainer.infer_all(
                images,
                numpy=True,
                stores=stores,
                bin_value=bin_value,
                half=half,
                cloudy=cloudy,
                return_masks=save_masks,
            )

            # save the input if requested, rescaling [-1, 1] back to [0, 255]
            if args.save_input:
                events["input"] = uint8((images + 1) / 2 * 255)

            all_events.append(events)

    if outdir is not None or upload:

        if upload:
            print("\n• Creating comet Experiment")
            exp = comet_ml.Experiment(project_name="climategan-apply")
            exp.log_parameters(vars(args))

        # flatten all_events from a list of dicts of batched arrays
        # to a list of per-image dicts {event_name: single image}
        to_write = []
        events_names = list(all_events[0].keys())
        for events_data in all_events:
            n_ims = len(events_data[events_names[0]])
            for i in range(n_ims):
                item = {event: events_data[event][i] for event in events_names}
                to_write.append(item)

        progress_bar_desc = ""
        if outdir is not None:
            print("\n• Output directory:\n")
            print(str(outdir), "\n")
            if upload:
                progress_bar_desc = "Writing & Uploading events"
            else:
                progress_bar_desc = "Writing events"
        else:
            if upload:
                progress_bar_desc = "Uploading events"

        with Timer(store=stores.get("write", []), ignore=time_inference):

            for t, event_dict in tqdm(
                enumerate(to_write),
                desc=progress_bar_desc,
                unit="input image",
                total=len(to_write),
            ):

                idx = t % len(base_data_paths)
                stem = Path(data_paths[idx]).stem
                width = new_sizes[idx][1]

                if keep_ratio:
                    ar = "_AR"
                else:
                    ar = ""

                event_bar = tqdm(
                    enumerate(event_dict.items()),
                    leave=False,
                    total=len(events_names),
                    unit="event",
                )
                for e, (event, im_data) in event_bar:
                    event_bar.set_description(
                        f"  {event.capitalize():<{len(progress_bar_desc) - 2}}"
                    )

                    if args.no_cloudy:
                        suffix = ar + "_no_cloudy"
                    else:
                        suffix = ar

                    im_path = Path(f"{stem}_{event}_{width}{suffix}.png")

                    if outdir is not None:
                        im_path = outdir / im_path
                        io.imsave(im_path, im_data)

                    if upload:
                        exp.log_image(im_data, name=im_path.name)

        if zip_outdir:
            print("\n• Zipping output directory... ", end="", flush=True)
            archive_path = Path(shutil.make_archive(outdir.name, "zip", root_dir=outdir))
            archive_path = archive_path.rename(outdir.parent / archive_path.name)
            print("Done:\n")
            print(str(archive_path))
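
        # e.g. outdir /data/out/half-AR-1280 is first archived in the current
        # working directory as half-AR-1280.zip, then moved to
        # /data/out/half-AR-1280.zip (hypothetical paths)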

    if time_inference:
        print("\n• Timings\n")
        print_store(stores)

    if not args.no_conf and outdir is not None:
        write_apply_config(outdir)