Spaces:
Runtime error
Runtime error
add inference script features
Browse files- climategan_wrapper.py +82 -1
climategan_wrapper.py
CHANGED
@@ -5,7 +5,7 @@ import os
|
|
5 |
import re
|
6 |
from pathlib import Path
|
7 |
from uuid import uuid4
|
8 |
-
|
9 |
import numpy as np
|
10 |
import torch
|
11 |
from diffusers import StableDiffusionInpaintPipeline
|
@@ -541,3 +541,84 @@ class ClimateGAN:
|
|
541 |
im = Image.fromarray(uint8(im))
|
542 |
imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
|
543 |
im.save(im_path.parent / (imstem + im_path.suffix))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import re
|
6 |
from pathlib import Path
|
7 |
from uuid import uuid4
|
8 |
+
from minydra import resolved_args
|
9 |
import numpy as np
|
10 |
import torch
|
11 |
from diffusers import StableDiffusionInpaintPipeline
|
|
|
541 |
im = Image.fromarray(uint8(im))
|
542 |
imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
|
543 |
im.save(im_path.parent / (imstem + im_path.suffix))
|
544 |
+
|
545 |
+
|
546 |
+
if __name__ == "__main__":
|
547 |
+
print("Run `$ python climategan_wrapper.py help` for usage instructions\n")
|
548 |
+
|
549 |
+
# parse arguments
|
550 |
+
args = resolved_args(
|
551 |
+
defaults={
|
552 |
+
"input_folder": None,
|
553 |
+
"output_folder": None,
|
554 |
+
"painter": "both",
|
555 |
+
"help": False,
|
556 |
+
}
|
557 |
+
)
|
558 |
+
|
559 |
+
# print help
|
560 |
+
if args.help:
|
561 |
+
print(
|
562 |
+
"Usage: python inference.py input_folder=/path/to/folder\n"
|
563 |
+
+ "By default inferences will be stored in the input folder.\n"
|
564 |
+
+ "Add `output_folder=/path/to/folder` for a different output folder.\n"
|
565 |
+
+ "By default, both ClimateGAN and Stable Diffusion will be used."
|
566 |
+
+ "Change this by adding `painter=climategan` or"
|
567 |
+
+ " `painter=stable_diffusion`.\n"
|
568 |
+
+ "Make sure you have agreed to the terms of use for the models."
|
569 |
+
+ "In particular, visit SD's model card to agree to the terms of use:"
|
570 |
+
+ " https://huggingface.co/runwayml/stable-diffusion-inpainting"
|
571 |
+
)
|
572 |
+
# print args
|
573 |
+
args.pretty_print()
|
574 |
+
|
575 |
+
# load models
|
576 |
+
cg = ClimateGAN("models/climategan")
|
577 |
+
|
578 |
+
# check painter type
|
579 |
+
assert args.painter in {"climategan", "stable_diffusion", "both",}, (
|
580 |
+
f"Unknown painter {args.painter}. "
|
581 |
+
+ "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
|
582 |
+
)
|
583 |
+
|
584 |
+
# load SD pipeline if need be
|
585 |
+
if args.painter != "climate_gan":
|
586 |
+
cg._setup_stable_diffusion()
|
587 |
+
|
588 |
+
# resolve input folder path
|
589 |
+
in_path = Path(args.input_folder).expanduser().resolve()
|
590 |
+
assert in_path.exists(), f"Folder {str(in_path)} does not exist"
|
591 |
+
|
592 |
+
# output is input if not specified
|
593 |
+
if args.output_folder is None:
|
594 |
+
out_path = in_path
|
595 |
+
|
596 |
+
# find images in input folder
|
597 |
+
im_paths = [
|
598 |
+
p
|
599 |
+
for p in in_path.iterdir()
|
600 |
+
if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
|
601 |
+
]
|
602 |
+
assert im_paths, f"No images found in {str(im_paths)}"
|
603 |
+
|
604 |
+
print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")
|
605 |
+
|
606 |
+
# infer and write
|
607 |
+
for i, im_path in enumerate(im_paths):
|
608 |
+
print(">>> Processing", f"{i}/{len(im_paths)}", im_path.name)
|
609 |
+
outs = cg.infer_single(
|
610 |
+
np.array(Image.open(im_path)),
|
611 |
+
args.painter,
|
612 |
+
as_pil_image=True,
|
613 |
+
concats=[
|
614 |
+
"input",
|
615 |
+
"masked_input",
|
616 |
+
"climategan_flood",
|
617 |
+
"stable_copy_flood",
|
618 |
+
],
|
619 |
+
)
|
620 |
+
for k, v in outs.items():
|
621 |
+
name = f"{im_path.stem}---{k}{im_path.suffix}"
|
622 |
+
im = Image.fromarray(uint8(v))
|
623 |
+
im.save(out_path / name)
|
624 |
+
print(">>> Done", f"{i}/{len(im_paths)}", im_path.name, end="\n\n")
|