# climateGAN / utils_scripts / compare_maskers.py
# vict0rsch's picture
# initial commit from `vict0rsch/climateGAN`
# ce190ee
# raw
# history blame
# 9.61 kB
import sys
from argparse import ArgumentParser
from pathlib import Path
from comet_ml import Experiment
import numpy as np
import torch
import yaml
from PIL import Image
from skimage.color import gray2rgb
from skimage.io import imread
from skimage.transform import resize
from skimage.util import img_as_ubyte
from tqdm import tqdm
sys.path.append(str(Path(__file__).resolve().parent.parent))
import climategan
# Hard-coded run directory of the "ground" baseline: get_or_load_inferences()
# passes it to load_ground(), which looks for prediction images under
# <GROUND_MODEL>/eval-metrics/pred.
GROUND_MODEL = "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--ground"
def uint8(array):
    """
    Casts an array to unsigned 8-bit integers.

    Args:
        array (np.ndarray): Array to convert

    Returns:
        np.ndarray: Result of ``array.astype(np.uint8)``
    """
    as_bytes = array.astype(np.uint8)
    return as_bytes
def crop_and_resize(image_path, label_path, out_size=640):
    """
    Resizes an image so that it keeps the aspect ratio and its smallest dimension
    is ``out_size``, then crops this resized image in its center so that the output
    is ``out_size`` x ``out_size`` without aspect ratio distortion.

    The label is resized to the same target size as the image (nearest-neighbour
    interpolation, no anti-aliasing) and cropped with the same window.

    Args:
        image_path (Path or str): Path to an image
        label_path (Path or str): Path to the image's associated label
        out_size (int, optional): Side of the square output. Defaults to 640.

    Returns:
        tuple((np.ndarray, np.ndarray)): (new image, new label), both uint8
    """
    img = imread(image_path)
    lab = imread(label_path)
    # NOTE(review): channels are left untouched here, so an image read with an
    # alpha channel keeps it -- confirm callers do not need an RGB conversion.

    if img.shape[:2] != lab.shape[:2]:
        # Path(...) so that plain string paths (allowed by the signature) work too
        print(
            "\nWARNING: shape mismatch: im -> {}, lab -> {}".format(
                Path(image_path).name, Path(label_path).name
            )
        )

    # resize keeping aspect ratio: smallest dim is out_size
    h, w = img.shape[:2]
    if h < w:
        new_shape = (out_size, int(out_size * w / h))
    else:
        new_shape = (int(out_size * h / w), out_size)

    r_img = uint8(resize(img, new_shape, preserve_range=True, anti_aliasing=True))
    # order=0 + no anti-aliasing: label values must not be blended together
    r_lab = uint8(
        resize(lab, new_shape, preserve_range=True, anti_aliasing=False, order=0)
    )

    # crop in the center; only the two leading axes are sliced so that both
    # 2-D (grayscale) and 3-D (multi-channel) arrays are supported
    H, W = r_img.shape[:2]
    top = (H - out_size) // 2
    left = (W - out_size) // 2
    rc_img = r_img[top : top + out_size, left : left + out_size]
    rc_lab = r_lab[top : top + out_size, left : left + out_size]

    return rc_img, rc_lab
def load_ground(ground_output_path, ref_image_path):
    """
    Loads the ground prediction matching a reference image as a binary mask.

    A file named like the reference image's stem, with a ``.jpg`` or ``.png``
    extension, is looked for in ``<ground_output_path>/eval-metrics/pred``. It is
    cropped and resized together with the reference image by ``crop_and_resize``
    and then binarized: every strictly positive value becomes 1.0.

    Args:
        ground_output_path (Path or str): Output directory of the ground model
        ref_image_path (Path or str): Path to the image the mask is needed for

    Raises:
        ValueError: If no matching file is found
        ValueError: If more than one matching file is found

    Returns:
        torch.Tensor: float32 mask with two leading singleton dimensions, on CUDA
    """
    out_dir = Path(ground_output_path)
    ref = Path(ref_image_path)
    pred_dir = out_dir / "eval-metrics" / "pred"

    # .jpg candidates first, then .png ones
    candidates = [
        found for ext in ("jpg", "png") for found in pred_dir.glob(f"{ref.stem}.{ext}")
    ]

    if not candidates:
        raise ValueError(
            f"Could not find a ground match in {str(out_dir)} for image {str(ref)}"
        )
    if len(candidates) > 1:
        raise ValueError(
            f"Found more than 1 ground match in {str(out_dir)} for image {str(ref)}:"
            + f" {list(map(str, candidates))}"
        )

    (match,) = candidates
    _, label = crop_and_resize(ref, match)
    binary = (label > 0).astype(np.float32)
    mask = torch.from_numpy(binary)
    mask = mask.unsqueeze(0).unsqueeze(0)
    return mask.cuda()
def parse_args():
    """
    Parses the script's command-line arguments and prints them.

    Arguments:
        -y / --yaml: path to a yaml file listing the models to compare
        --disable_loading: flag, do not re-use inferences saved on disk
        -t / --tags: tags to add to the Comet.ml experiment
        --tasks: which outputs to show for each model, in order

    Returns:
        argparse.Namespace: Parsed arguments with attributes ``yaml``,
            ``disable_loading``, ``tags`` and ``tasks``
    """
    parser = ArgumentParser()
    parser.add_argument("-y", "--yaml", help="Path to a list of models")
    parser.add_argument(
        "--disable_loading",
        action="store_true",
        default=False,
        help="Disable loading of existing inferences",
    )
    parser.add_argument(
        "-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str
    )
    parser.add_argument(
        "--tasks",
        nargs="*",
        # this help text used to be a copy of the --tags one
        help="Outputs to concatenate for each model, in order (x: input, d: depth,"
        + " s: segmentation, m: mask, mx: masked input, p: painted)",
        default=["x", "d", "s", "m", "mx", "p"],
        type=str,
    )
    args = parser.parse_args()

    print("Received args:")
    print(vars(args))

    return args
def load_images_and_labels(
    path="/miniscratch/_groups/ccai/data/omnigan/masker-test-set",
):
    """
    Finds and prepares the images and labels of a test set.

    Images are searched in ``<path>/imgs`` and labels in ``<path>/labels`` with
    ``climategan.utils.find_images``. Images are sorted by file name; labels are
    sorted by file name after replacing ``"_labeled."`` with ``"."``. Both lists
    are then passed through ``climategan.transforms.PrepareInference``
    (with ``is_label=True`` for the labels).

    Args:
        path (str or Path, optional): Root directory of the test set

    Returns:
        tuple: (prepared images, prepared labels, image paths, label paths)
    """
    root = Path(path)

    def image_key(file):
        return file.name

    def label_key(file):
        # drop the labelling suffix before comparing names
        return file.name.replace("_labeled.", ".")

    image_paths = sorted(climategan.utils.find_images(root / "imgs"), key=image_key)
    label_paths = sorted(climategan.utils.find_images(root / "labels"), key=label_key)

    prepare_images = climategan.transforms.PrepareInference()
    xs = prepare_images(image_paths)
    prepare_labels = climategan.transforms.PrepareInference(is_label=True)
    ys = prepare_labels(label_paths)

    return xs, ys, image_paths, label_paths
def load_inferences(inf_path, im_paths):
    """
    Loads the inferences previously saved in ``inf_path``: one ``<stem>.pt`` file
    is expected per image in ``im_paths``, and no other ``.pt`` file.

    Inferences are returned in the order of ``im_paths`` (not in directory-listing
    order), so that item ``i`` of the result matches ``im_paths[i]``.

    Any failure is reported on stdout and makes the function return ``None``
    instead of raising: callers treat ``None`` as "inferences must be recomputed".

    Args:
        inf_path (Path or str): Directory holding the saved ``.pt`` files
        im_paths (list(Path)): Paths of the images the inferences belong to

    Returns:
        list or None: Loaded objects ordered like ``im_paths``, or ``None`` when
            the directory is missing, the file stems do not match the image stems,
            or loading fails
    """
    try:
        inf_dir = Path(inf_path)
        # explicit checks rather than asserts: they survive `python -O` and
        # produce a message that is worth printing
        if not inf_dir.exists():
            raise FileNotFoundError(f"No inferences directory at {inf_dir}")

        expected = [i.stem for i in im_paths]
        found = [i.stem for i in inf_dir.glob("*.pt")]
        if sorted(expected) != sorted(found):
            raise ValueError(
                f"Inferences in {inf_dir} do not match the images:"
                + f" expected {len(expected)} file(s), found {len(found)}"
            )

        # NOTE: torch.load unpickles data, only use on trusted directories
        return [
            torch.load(str(inf_dir / f"{stem}.pt")) for stem in tqdm(expected)
        ]
    except Exception as e:
        # deliberately broad: loading is best-effort, the caller recomputes
        print()
        print(e)
        print("Aborting Loading")
        print()
        return None
def get_or_load_inferences(
    m_path, device, xs, is_ground, im_paths, ground_model, try_load=True
):
    """
    Returns a model's inferences on ``xs``: the ones saved in
    ``<m_path>/inferences`` when they can be loaded, freshly computed (and saved
    there, one ``<image stem>.pt`` file per input) otherwise.

    For a regular model the outputs come from ``trainer.G.decode``. For the ground
    baseline (``is_ground``) the mask is read with ``load_ground`` from
    ``GROUND_MODEL`` and the trainer is resumed from ``ground_model`` instead of
    ``m_path``. In both cases the painted image is produced by ``trainer.G.paint``
    from the mask thresholded at 0.5.

    Args:
        m_path (str or Path): Path to the model's run directory
        device (torch.device): Device passed to the trainer
        xs (iterable): Prepared inputs, each must support ``.to(device)``
        is_ground (bool): Whether this entry is the ground baseline
        im_paths (list(Path)): Image paths, indexed like ``xs``
        ground_model (str or Path): Run to resume the trainer from when
            ``is_ground`` is True
        try_load (bool, optional): Try to re-use saved inferences first.
            Defaults to True.

    Returns:
        list(dict): One dict of CPU tensors per input, with at least the keys
            ``"m"``, ``"p"`` and ``"x"`` when freshly computed
    """
    inf_path = Path(m_path) / "inferences"

    if try_load:
        print("Trying to load existing inferences:")
        cached = load_inferences(inf_path, im_paths)
        if cached is not None:
            print("Successfully loaded existing inferences")
            return cached

    checkpoint = ground_model if is_ground else m_path
    trainer = climategan.trainer.Trainer.resume_from_path(
        checkpoint, inference=True, new_exp=None, device=device
    )

    inf_path.mkdir(exist_ok=True)
    outputs = []
    for idx, batch in enumerate(tqdm(xs)):
        batch = batch.to(trainer.device)

        if is_ground:
            out = {"m": load_ground(GROUND_MODEL, im_paths[idx])}
        else:
            out = trainer.G.decode(x=batch)

        out["p"] = trainer.G.paint(out["m"] > 0.5, batch)
        out["x"] = batch

        # everything goes back to the CPU before being kept and written to disk
        on_cpu = {}
        for key, value in out.items():
            on_cpu[key] = value.cpu()

        outputs.append(on_cpu)
        torch.save(on_cpu, inf_path / f"{im_paths[idx].stem}.pt")
    print()

    return outputs
def numpify(outputs):
    """
    Converts a list of inference dicts (tensors) into dicts of numpy images.

    For each inference:
        * ``"x"`` and ``"p"``: first batch item, channels moved last, mapped with
          ``(v + 1) / 2``
        * ``"m"``: item ``[0, 0]`` thresholded at 0.5
        * ``"s"`` (optional): decoded with
          ``climategan.data.decode_segmap_merged_labels`` and divided by 255
        * ``"d"`` (optional): normalized with
          ``climategan.tutils.normalize_tensor`` and squeezed

    Every resulting array goes through ``img_as_ubyte``.

    Args:
        outputs (list(dict)): Inferences, each with at least ``"x"``, ``"m"``
            and ``"p"``

    Returns:
        list(dict): One dict of arrays per inference, same keys as listed above
    """
    print("Numpifying...")
    results = []
    for out in tqdm(outputs):
        image = (out["x"][0].permute(1, 2, 0).numpy() + 1) / 2
        mask = (out["m"][0, 0, :, :].numpy() > 0.5).astype(np.uint8)
        painted = (out["p"][0].permute(1, 2, 0).numpy() + 1) / 2
        arrays = {"m": mask, "p": painted, "x": image}

        if "s" in out:
            seg = climategan.data.decode_segmap_merged_labels(out["s"], "r", False)
            seg = seg / 255.0
            arrays["s"] = seg[0].permute(1, 2, 0).numpy()

        if "d" in out:
            depth = climategan.tutils.normalize_tensor(out["d"])
            arrays["d"] = depth.squeeze().numpy()

        converted = {}
        for key, value in arrays.items():
            converted[key] = img_as_ubyte(value)
        results.append(converted)

    return results
def concat_npy_for_model(data, tasks):
    """
    Builds one row of the comparison image for a single model and a single input
    by concatenating the requested outputs horizontally.

    Known tasks:
        * ``"x"``: ``data["x"]``
        * ``"m"``: ``data["m"]`` as a 3-channel image multiplied by 255
        * ``"mx"``: ``data["x"]`` multiplied by ``1 - mask``
        * ``"d"``: ``data["d"]`` resized to the input's height and width, as
          3 channels; an all-255 image when the model has no ``"d"``
        * ``"s"``: ``data["s"]`` resized to the input's height and width; an
          all-255 image when the model has no ``"s"``
        * ``"p"``: ``data["p"]``

    Unknown task names are ignored. Each panel is only computed if requested.

    Args:
        data (dict): Arrays for one input, must contain ``"m"``, ``"x"``, ``"p"``
        tasks (list(str)): Names of the panels to concatenate, in order

    Raises:
        KeyError: If ``"m"``, ``"x"`` or ``"p"`` is missing from ``data``
        ValueError: If no known task is requested

    Returns:
        np.ndarray: The selected panels concatenated along axis 1
    """
    missing = [key for key in ("m", "x", "p") if key not in data]
    if missing:
        raise KeyError(f"Missing required outputs in data: {missing}")

    x = data["x"]
    mask_rgb = gray2rgb(data["m"])
    target_hw = x.shape[:2]

    def blank():
        # placeholder panel for outputs this model does not produce
        return np.ones_like(x) * 255

    def depth():
        if "d" not in data:
            return blank()
        resized = resize(data["d"], target_hw, anti_aliasing=True, order=1)
        return img_as_ubyte(gray2rgb(resized))

    def seg():
        if "s" not in data:
            return blank()
        # nearest neighbour: segmentation colors must not be blended
        return img_as_ubyte(
            resize(data["s"], target_hw, anti_aliasing=False, order=0)
        )

    builders = {
        "x": lambda: x,
        "m": lambda: (mask_rgb * 255).astype(np.uint8),
        "mx": lambda: (1 - mask_rgb) * x,
        "d": depth,
        "s": seg,
        "p": lambda: data["p"],
    }

    concats = [builders[t]() for t in tasks if t in builders]
    if not concats:
        raise ValueError(
            f"No known task in {list(tasks)}, expected some of {list(builders)}"
        )

    return np.concatenate(concats, axis=1)
if __name__ == "__main__":
    # Compares several maskers on the same test images: runs (or loads) each
    # model's inferences, then logs to Comet.ml one image per test input where
    # each row is a model and each column one of the requested tasks.
    args = parse_args()

    with open(args.yaml, "r") as f:
        maskers = yaml.safe_load(f)
    # the yaml file is either a plain list of model paths or {"models": [...]}
    if "models" in maskers:
        maskers = maskers["models"]

    load = not args.disable_loading
    tags = args.tags
    tasks = args.tasks

    # Despite its name, `ground_model` is the first NON-ground model: its trainer
    # provides the painter used when processing the ground baseline.
    # The test must be done on each path `m` (it used to be done on the whole
    # `maskers` list, which selected the first entry even if it was the ground
    # one); "--ground" is the same marker as the one used to name models below.
    ground_model = next((m for m in maskers if "--ground" not in m), None)
    if ground_model is None:
        raise ValueError("Could not find a non-ground model to get a painter")

    device = torch.device("cuda:0")
    torch.set_grad_enabled(False)

    xs, ys, im_paths, lab_paths = load_images_and_labels()

    np_outs = {}
    names = []

    for m_path in maskers:
        opt_path = Path(m_path) / "opts.yaml"
        with opt_path.open("r") as f:
            opt = yaml.safe_load(f)

        # a model is named after its comet tags, minus the bookkeeping ones
        name = (
            ", ".join(
                [
                    t
                    for t in sorted(opt["comet"]["tags"])
                    if "branch" not in t and "ablation" not in t and "trash" not in t
                ]
            )
            if "--ground" not in m_path
            else "ground"
        )
        names.append(name)

        is_ground = name == "ground"

        print("#" * 100)
        print("\n>>> Processing", name)
        print()

        outputs = get_or_load_inferences(
            m_path, device, xs, is_ground, im_paths, ground_model, load
        )
        nps = numpify(outputs)

        np_outs[name] = nps

    exp = Experiment(project_name="climategan-inferences", display_summary_level=0)
    exp.log_parameter("names", names)
    exp.add_tags(tags)

    for i in tqdm(range(len(xs))):
        # one row per model, stacked vertically
        all_models_for_image = []
        for name in names:
            xpmds = concat_npy_for_model(np_outs[name][i], tasks)
            all_models_for_image.append(xpmds)
        full_im = np.concatenate(all_models_for_image, axis=0)
        pil_im = Image.fromarray(full_im)
        exp.log_image(pil_im, name=im_paths[i].stem.replace(".", "_"), step=i)