Spaces:
Runtime error
Runtime error
import sys | |
from argparse import ArgumentParser | |
from pathlib import Path | |
from comet_ml import Experiment | |
import numpy as np | |
import torch | |
import yaml | |
from PIL import Image | |
from skimage.color import gray2rgb | |
from skimage.io import imread | |
from skimage.transform import resize | |
from skimage.util import img_as_ubyte | |
from tqdm import tqdm | |
sys.path.append(str(Path(__file__).resolve().parent.parent)) | |
import climategan | |
GROUND_MODEL = "/miniscratch/_groups/ccai/experiments/runs/ablation-v1/out--ground" | |
def uint8(array): | |
return array.astype(np.uint8) | |
def crop_and_resize(image_path, label_path): | |
""" | |
Resizes an image so that it keeps the aspect ratio and the smallest dimensions | |
is 640, then crops this resized image in its center so that the output is 640x640 | |
without aspect ratio distortion | |
Args: | |
image_path (Path or str): Path to an image | |
label_path (Path or str): Path to the image's associated label | |
Returns: | |
tuple((np.ndarray, np.ndarray)): (new image, new label) | |
""" | |
img = imread(image_path) | |
lab = imread(label_path) | |
# if img.shape[-1] == 4: | |
# img = uint8(rgba2rgb(img) * 255) | |
# TODO: remove (debug) | |
if img.shape[:2] != lab.shape[:2]: | |
print( | |
"\nWARNING: shape mismatch: im -> {}, lab -> {}".format( | |
image_path.name, label_path.name | |
) | |
) | |
# breakpoint() | |
# resize keeping aspect ratio: smallest dim is 640 | |
h, w = img.shape[:2] | |
if h < w: | |
size = (640, int(640 * w / h)) | |
else: | |
size = (int(640 * h / w), 640) | |
r_img = resize(img, size, preserve_range=True, anti_aliasing=True) | |
r_img = uint8(r_img) | |
r_lab = resize(lab, size, preserve_range=True, anti_aliasing=False, order=0) | |
r_lab = uint8(r_lab) | |
# crop in the center | |
H, W = r_img.shape[:2] | |
top = (H - 640) // 2 | |
left = (W - 640) // 2 | |
rc_img = r_img[top : top + 640, left : left + 640, :] | |
rc_lab = ( | |
r_lab[top : top + 640, left : left + 640, :] | |
if r_lab.ndim == 3 | |
else r_lab[top : top + 640, left : left + 640] | |
) | |
return rc_img, rc_lab | |
def load_ground(ground_output_path, ref_image_path): | |
gop = Path(ground_output_path) | |
rip = Path(ref_image_path) | |
ground_paths = list((gop / "eval-metrics" / "pred").glob(f"{rip.stem}.jpg")) + list( | |
(gop / "eval-metrics" / "pred").glob(f"{rip.stem}.png") | |
) | |
if len(ground_paths) == 0: | |
raise ValueError( | |
f"Could not find a ground match in {str(gop)} for image {str(rip)}" | |
) | |
elif len(ground_paths) > 1: | |
raise ValueError( | |
f"Found more than 1 ground match in {str(gop)} for image {str(rip)}:" | |
+ f" {list(map(str, ground_paths))}" | |
) | |
ground_path = ground_paths[0] | |
_, ground = crop_and_resize(rip, ground_path) | |
ground = (ground > 0).astype(np.float32) | |
return torch.from_numpy(ground).unsqueeze(0).unsqueeze(0).cuda() | |
def parse_args(): | |
parser = ArgumentParser() | |
parser.add_argument("-y", "--yaml", help="Path to a list of models") | |
parser.add_argument( | |
"--disable_loading", | |
action="store_true", | |
default=False, | |
help="Disable loading of existing inferences", | |
) | |
parser.add_argument( | |
"-t", "--tags", nargs="*", help="Comet.ml tags", default=[], type=str | |
) | |
parser.add_argument( | |
"--tasks", | |
nargs="*", | |
help="Comet.ml tags", | |
default=["x", "d", "s", "m", "mx", "p"], | |
type=str, | |
) | |
args = parser.parse_args() | |
print("Received args:") | |
print(vars(args)) | |
return args | |
def load_images_and_labels( | |
path="/miniscratch/_groups/ccai/data/omnigan/masker-test-set", | |
): | |
p = Path(path) | |
ims_path = p / "imgs" | |
lab_path = p / "labels" | |
ims = sorted(climategan.utils.find_images(ims_path), key=lambda x: x.name) | |
labs = sorted( | |
climategan.utils.find_images(lab_path), | |
key=lambda x: x.name.replace("_labeled.", "."), | |
) | |
xs = climategan.transforms.PrepareInference()(ims) | |
ys = climategan.transforms.PrepareInference(is_label=True)(labs) | |
return xs, ys, ims, labs | |
def load_inferences(inf_path, im_paths): | |
try: | |
assert inf_path.exists() | |
assert sorted([i.stem for i in im_paths]) == sorted( | |
[i.stem for i in inf_path.glob("*.pt")] | |
) | |
return [torch.load(str(i)) for i in tqdm(list(inf_path.glob("*.pt")))] | |
except Exception as e: | |
print() | |
print(e) | |
print("Aborting Loading") | |
print() | |
return None | |
def get_or_load_inferences( | |
m_path, device, xs, is_ground, im_paths, ground_model, try_load=True | |
): | |
inf_path = Path(m_path) / "inferences" | |
if try_load: | |
print("Trying to load existing inferences:") | |
outputs = load_inferences(inf_path, im_paths) | |
if outputs is not None: | |
print("Successfully loaded existing inferences") | |
return outputs | |
trainer = climategan.trainer.Trainer.resume_from_path( | |
m_path if not is_ground else ground_model, | |
inference=True, | |
new_exp=None, | |
device=device, | |
) | |
inf_path.mkdir(exist_ok=True) | |
outputs = [] | |
for i, x in enumerate(tqdm(xs)): | |
x = x.to(trainer.device) | |
if not is_ground: | |
out = trainer.G.decode(x=x) | |
else: | |
out = {"m": load_ground(GROUND_MODEL, im_paths[i])} | |
out["p"] = trainer.G.paint(out["m"] > 0.5, x) | |
out["x"] = x | |
inference = {k: v.cpu() for k, v in out.items()} | |
outputs.append(inference) | |
torch.save(inference, inf_path / f"{im_paths[i].stem}.pt") | |
print() | |
return outputs | |
def numpify(outputs): | |
nps = [] | |
print("Numpifying...") | |
for o in tqdm(outputs): | |
x = (o["x"][0].permute(1, 2, 0).numpy() + 1) / 2 | |
m = o["m"] | |
m = (m[0, 0, :, :].numpy() > 0.5).astype(np.uint8) | |
p = (o["p"][0].permute(1, 2, 0).numpy() + 1) / 2 | |
data = {"m": m, "p": p, "x": x} | |
if "s" in o: | |
s = climategan.data.decode_segmap_merged_labels(o["s"], "r", False) / 255.0 | |
data["s"] = s[0].permute(1, 2, 0).numpy() | |
if "d" in o: | |
d = climategan.tutils.normalize_tensor(o["d"]).squeeze().numpy() | |
data["d"] = d | |
nps.append({k: img_as_ubyte(v) for k, v in data.items()}) | |
return nps | |
def concat_npy_for_model(data, tasks): | |
assert "m" in data | |
assert "x" in data | |
assert "p" in data | |
x = mask = depth = seg = painted = masked = None | |
x = data["x"] | |
painted = data["p"] | |
mask = (gray2rgb(data["m"]) * 255).astype(np.uint8) | |
painted = data["p"] | |
masked = (1 - gray2rgb(data["m"])) * x | |
concats = [] | |
if "d" in data: | |
depth = img_as_ubyte( | |
gray2rgb( | |
resize(data["d"], data["x"].shape[:2], anti_aliasing=True, order=1) | |
) | |
) | |
else: | |
depth = np.ones_like(data["x"]) * 255 | |
if "s" in data: | |
seg = img_as_ubyte( | |
resize(data["s"], data["x"].shape[:2], anti_aliasing=False, order=0) | |
) | |
else: | |
seg = np.ones_like(data["x"]) * 255 | |
for t in tasks: | |
if t == "x": | |
concats.append(x) | |
if t == "m": | |
concats.append(mask) | |
elif t == "mx": | |
concats.append(masked) | |
elif t == "d": | |
concats.append(depth) | |
elif t == "s": | |
concats.append(seg) | |
elif t == "p": | |
concats.append(painted) | |
row = np.concatenate(concats, axis=1) | |
return row | |
if __name__ == "__main__": | |
args = parse_args() | |
with open(args.yaml, "r") as f: | |
maskers = yaml.safe_load(f) | |
if "models" in maskers: | |
maskers = maskers["models"] | |
load = not args.disable_loading | |
tags = args.tags | |
tasks = args.tasks | |
ground_model = None | |
for m in maskers: | |
if "ground" not in maskers: | |
ground_model = m | |
break | |
if ground_model is None: | |
raise ValueError("Could not find a non-ground model to get a painter") | |
device = torch.device("cpu") | |
torch.set_grad_enabled(False) | |
xs, ys, im_paths, lab_paths = load_images_and_labels() | |
np_outs = {} | |
names = [] | |
for m_path in maskers: | |
opt_path = Path(m_path) / "opts.yaml" | |
with opt_path.open("r") as f: | |
opt = yaml.safe_load(f) | |
name = ( | |
", ".join( | |
[ | |
t | |
for t in sorted(opt["comet"]["tags"]) | |
if "branch" not in t and "ablation" not in t and "trash" not in t | |
] | |
) | |
if "--ground" not in m_path | |
else "ground" | |
) | |
names.append(name) | |
is_ground = name == "ground" | |
print("#" * 100) | |
print("\n>>> Processing", name) | |
print() | |
outputs = get_or_load_inferences( | |
m_path, device, xs, is_ground, im_paths, ground_model, load | |
) | |
nps = numpify(outputs) | |
np_outs[name] = nps | |
exp = Experiment(project_name="climategan-inferences", display_summary_level=0) | |
exp.log_parameter("names", names) | |
exp.add_tags(tags) | |
for i in tqdm(range(len(xs))): | |
all_models_for_image = [] | |
for name in names: | |
xpmds = concat_npy_for_model(np_outs[name][i], tasks) | |
all_models_for_image.append(xpmds) | |
full_im = np.concatenate(all_models_for_image, axis=0) | |
pil_im = Image.fromarray(full_im) | |
exp.log_image(pil_im, name=im_paths[i].stem.replace(".", "_"), step=i) | |