Spaces:
Runtime error
Runtime error
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501 | |
# thank you @NimaBoscarino | |
import os | |
import re | |
from pathlib import Path | |
from uuid import uuid4 | |
from minydra import resolved_args | |
import numpy as np | |
import torch | |
from diffusers import StableDiffusionInpaintPipeline | |
from PIL import Image | |
from skimage.color import rgba2rgb | |
from skimage.transform import resize | |
from climategan.trainer import Trainer | |
# Whether torch reports an available CUDA device. Used further down to switch
# the ClimateGAN generator and the Stable Diffusion pipeline to half precision.
CUDA = torch.cuda.is_available()
def concat_events(output_dict, events, i=None, axis=1):
    """
    Concatenate the data stored in `output_dict` under the keys listed in
    `events` along dimension `axis`, and return the result as np.uint8.

    Keys of `events` that are missing from `output_dict` are silently skipped.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data:
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate, in order.
        i (int, optional): If not None, only the `i`th item of each entry is
            concatenated. Defaults to None (concatenate the entries as a whole).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): The concatenated data.
    """
    present = [key for key in events if key in output_dict]
    if i is None:
        pieces = [output_dict[key] for key in present]
    else:
        pieces = [output_dict[key][i] for key in present]
    return uint8(np.concatenate(pieces, axis=axis))
def clear(folder):
    """
    Deletes every regular file in `folder` whose stem contains the inference
    separator "---". Files without the separator and sub-directories are
    left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # Collect the targets first so that the directory is not modified while it
    # is being listed.
    targets = [
        entry
        for entry in Path(folder).iterdir()
        if entry.is_file() and "---" in entry.stem
    ]
    for target in targets:
        target.unlink()
def uint8(array, rescale=False):
    """
    Convert an array to np.uint8.

    With `rescale=False` only the dtype is changed. With `rescale=True` the
    values are first brought to the 0-255 range:
      * data with negative values must lie in [-1, 1] and is mapped to [0, 1];
      * data whose maximum is <= 1 is then multiplied by 255;
      * any other data is left as is before the cast.

    Args:
        array (np.array): array to convert
        rescale (bool, optional): Whether to rescale the values before
            casting. Defaults to False.

    Raises:
        ValueError: If `rescale` is True and the array has negative values
            while not being contained in [-1, 1].

    Returns:
        np.array(np.uint8): converted array
    """
    if not rescale:
        return array.astype(np.uint8)

    low = array.min()
    if low < 0:
        high = array.max()
        if not (low >= -1 and high <= 1):
            raise ValueError(f"Data range mismatch for image: ({low}, {high})")
        # [-1, 1] -> [0, 1]
        array = (array + 1) / 2

    if array.max() <= 1:
        # [0, 1] -> [0, 255]
        array = array * 255

    return array.astype(np.uint8)
def resize_and_crop(img, to=640):
    """
    Resizes an image, keeping its aspect ratio, so that its smallest dimension
    is `to`, then crops the resized image around its center so that the output
    is `to x to` without aspect ratio distortion.

    Args:
        img (np.array): np.uint8 255 image, HxWxC
        to (int, optional): Side of the square output. Defaults to 640.

    Returns:
        np.array: `to x to x C` float image in [0, 1] (the uint8 crop divided
            by 255.0)
    """
    height, width = img.shape[:2]

    # resize keeping aspect ratio: the smallest dimension becomes `to`
    if height < width:
        target_shape = (to, int(to * width / height))
    else:
        target_shape = (int(to * height / width), to)
    resized = uint8(resize(img, target_shape, preserve_range=True, anti_aliasing=True))

    # crop in the center
    new_height, new_width = resized.shape[:2]
    row_start = (new_height - to) // 2
    col_start = (new_width - to) // 2
    cropped = resized[row_start : row_start + to, col_start : col_start + to, :]

    return cropped / 255.0
def to_m1_p1(img):
    """
    Rescales a [0, 1] image to [-1, +1].

    Args:
        img (np.array): numpy array of an image in [0, 1]

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    in_unit_range = img.min() >= 0 and img.max() <= 1
    if not in_unit_range:
        raise ValueError(
            f"Data range mismatch for image: ({img.min()}, {img.max()})"
        )
    centered = img.astype(np.float32) - 0.5
    return centered * 2
# No need to do any timing in this, since it's just for the HF Space | |
class ClimateGAN:
    """
    A wrapper for the ClimateGAN model, optionally combined with a Stable
    Diffusion in-painting pipeline, to generate flood images from images or
    folders containing images.

    Inference methods return a dict of outputs. Besides whatever the ClimateGAN
    trainer returns, it contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion
            model from the mask and the input image (only if "painter" is
            "stable_diffusion" or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable
            diffusion model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The concatenation of the outputs listed in `concats`
            (only if `concats` is not empty).
    """

    # Default outputs concatenated into "concat". Kept as an immutable tuple so
    # it can safely be shared as a default argument between calls.
    _DEFAULT_CONCATS = (
        "input",
        "masked_input",
        "climategan_flood",
        "stable_flood",
        "stable_copy_flood",
    )
    _DEFAULT_PROMPT = "An HD picture of a street with dirty water after a heavy flood"
    # Model loaded by the in-painting pipeline (its license must be accepted).
    _SD_MODEL = "runwayml/stable-diffusion-inpainting"

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        Loads the ClimateGAN trainer in inference mode.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random data. Defaults to False.
        """
        # Inference only: disable autograd globally.
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        if CUDA:
            self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting.
        Make sure you have accepted the license on the model's card
        https://huggingface.co/runwayml/stable-diffusion-inpainting

        Does nothing in dev mode. Re-raises any loading error after printing
        a hint about the license.
        """
        if self.dev_mode:
            return
        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                self._SD_MODEL,
                revision="fp16" if CUDA else "main",
                torch_dtype=torch.float16 if CUDA else torch.float32,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
        except Exception:
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + f" card https://huggingface.co/{self._SD_MODEL}\n"
            )
            # Bare raise keeps the original traceback intact.
            raise
        self._stable_diffusion_is_setup = True

    def _preprocess_image(self, img):
        """
        Turns a HxWxC uint8 numpy array into a
        target_size x target_size x 3 float32 numpy array in [-1, 1].

        Args:
            img (np.array): Image to resize crop and rescale. Anything whose
                last dimension is not 3 is passed through `rgba2rgb`.

        Returns:
            np.array: Resized, cropped and rescaled image
        """
        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
        # to self.target_size
        data = resize_and_crop(data, self.target_size)
        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        return to_m1_p1(data)

    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        concats=_DEFAULT_CONCATS,
        as_pil_image=False,
    ):
        """
        Infers the image with the ClimateGAN model.

        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the
        networks. See the class docstring for the keys of the output dict.

        In dev mode, no inference is run and random arrays are returned.

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be a
                path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after a
                heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. An empty value disables the
                concatenation. Defaults to: ("input", "masked_input",
                "climategan_flood", "stable_flood", "stable_copy_flood").
            as_pil_image (bool, optional): If True, the image is also handed to
                the batch inference as a PIL image. Defaults to False.

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is
                omitted for masks (HxW).
        """
        if self.dev_mode:
            return {
                "input": orig_image,
                "mask": np.random.randint(0, 255, (640, 640)),
                "masked_input": np.random.randint(0, 255, (640, 640, 3)),
                "climategan_flood": np.random.randint(0, 255, (640, 640, 3)),
                "stable_flood": np.random.randint(0, 255, (640, 640, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (640, 640, 3)),
                "concat": np.random.randint(0, 255, (640, 640 * 5, 3)),
                "smog": np.random.randint(0, 255, (640, 640, 3)),
                "wildfire": np.random.randint(0, 255, (640, 640, 3)),
                "depth": np.random.randint(0, 255, (640, 640, 1)),
                "segmentation": np.random.randint(0, 255, (640, 640, 3)),
            }

        if isinstance(orig_image, (str, Path)):
            # Read the pixels eagerly so that the file can be closed right away.
            with Image.open(orig_image) as opened:
                image_array = np.array(opened)
        else:
            image_array = orig_image

        pil_image = Image.fromarray(image_array) if as_pil_image else None

        print("Preprocessing image")
        image = self._preprocess_image(image_array)
        output_dict = self.infer_preprocessed_batch(
            images=image[None, ...],
            painter=painter,
            prompt=prompt,
            concats=concats,
            pil_image=pil_image,
        )
        print("Inference done")
        # Batch of one: drop the batch dimension of every output.
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        concats=_DEFAULT_CONCATS,
        pil_image=None,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.

        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image(). See the class docstring for the keys of the
        output dict.

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after a
                heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. An empty value disables the
                concatenation. Defaults to: ("input", "masked_input",
                "climategan_flood", "stable_flood", "stable_copy_flood").
            pil_image (PIL.Image, optional): The original PIL image. If
                provided, it overrides `images` and is used for a single
                inference (batch_size=1)

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        if pil_image is not None:
            print("Warning: `pil_image` has been provided, it will override `images`")
            images = self._preprocess_image(np.array(pil_image))[None, ...]
            # Rebuild the PIL image from the preprocessed (resized and cropped)
            # data so that it matches `images`.
            pil_image = Image.fromarray(((images[0] + 1) / 2 * 255).astype(np.uint8))

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        print("Inferring ClimateGAN events")
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=CUDA,
            ignore_event=ignore_event,
            return_intermediates=True,
        )

        outputs["input"] = uint8(images, True)
        # from Bx1xHxW to BxHxWx1
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # "flood" was put in `ignore_event` for this painter so the trainer
            # may not return it at all: do not fail if the key is missing.
            outputs.pop("flood", None)

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            # The pipeline is fed tensors for a batch, PIL images for a single
            # image provided through `pil_image`.
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
                if pil_image is None
                else pil_image
            )
            input_mask = (
                torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
                if pil_image is None
                else Image.fromarray(mask[0])
            )

            print("Inferring stable diffusion in-painting for 50 steps")
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=640,
                width=640,
                num_inference_steps=50,
            )
            print("Stable diffusion in-painting done")

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # Paste the original context back outside of the mask.
            copy_flood = flood * bin_mask + uint8(images, True) * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            print("Concatenating flood images")
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # Drop singleton dimension 1 (e.g. Bx1xHxW masks become BxHxW).
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        batch_size=4,
        concats=_DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images
        for inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.
        Images are processed in sorted path order.

        If `write` is True, outputs are written to disk in the same folder as
        the input images (see `self.write` for the naming scheme). See the
        class docstring for the keys of the output dict.

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after a
                heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. An empty value disables the
                concatenation. Defaults to: ("input", "masked_input",
                "climategan_flood", "stable_flood", "stable_copy_flood").
            write (bool, optional): Whether or not to write the outputs to the
                input folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the
                images or not. If a string is provided, it will be included in
                the name. Defaults to False.

        Returns:
            dict: a dictionary {output_key: list of per-image arrays}
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        # Sorted so that the order of the outputs does not depend on the
        # file-system's listing order.
        im_paths = sorted(
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
        )
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = []
        for p in im_paths:
            # Context manager: do not leave one open file handle per image.
            with Image.open(p) as opened:
                ims.append(self._preprocess_image(np.array(opened)))

        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # Flatten the per-batch outputs into one list of images per key.
        outputs = {
            k: [im for batch in inferences for im in batch[k]] for k in inferences[0]
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
            f"{original_stem}---{overwrite_prefix}{painter_prefix}_{output_type}{suffix}"
        where `painter_prefix` is f"_stable_{prompt}" for the "stable_flood"
        output (with `prompt` lower-cased and stripped of non-alphanumerical
        characters), "climategan" for a raw "flood" output of the "climategan"
        painter, and empty otherwise.

        Args:
            outputs (dict): The inference procedure's output dict, mapping
                output types to per-image data.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the
                images or not. If a string is provided, it will be included in
                the name. If False, a random identifier will be added to the
                name. Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "".
        """
        prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()

        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
            print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                painter_prefix = ""
                # NOTE(review): infer_preprocessed_batch renames "flood" to
                # "climategan_flood", so this first branch only applies to
                # output dicts that still hold a raw "flood" key.
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif painter in {"stable_diffusion", "both"} and event == "stable_flood":
                    painter_prefix = f"_stable_{prompt}"

                im = Image.fromarray(uint8(ims[i]))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))
if __name__ == "__main__":
    # Command-line entry point: runs the inference on every image of a folder
    # and writes each output next to the inputs or in `output_folder`.
    script_name = Path(__file__).name
    print(f"Run `$ python {script_name} help` for usage instructions\n")

    # parse arguments
    args = resolved_args(
        defaults={
            "input_folder": None,
            "output_folder": None,
            "painter": "both",
            "help": False,
        }
    )

    # print help and stop: there is nothing to infer on
    if args.help:
        print(
            f"Usage: python {script_name} input_folder=/path/to/folder\n"
            + "By default inferences will be stored in the input folder.\n"
            + "Add `output_folder=/path/to/folder` for a different output folder.\n"
            + "By default, both ClimateGAN and Stable Diffusion will be used. "
            + "Change this by adding `painter=climategan` or"
            + " `painter=stable_diffusion`.\n"
            + "Make sure you have agreed to the terms of use for the models. "
            + "In particular, visit SD's model card to agree to the terms of use:"
            + " https://huggingface.co/runwayml/stable-diffusion-inpainting"
        )
        raise SystemExit(0)

    # print args
    args.pretty_print()

    # Validate every argument before loading the (slow) models.
    # check painter type
    if args.painter not in {"climategan", "stable_diffusion", "both"}:
        raise ValueError(
            f"Unknown painter {args.painter}. "
            + "Allowed values are 'climategan', 'stable_diffusion' and 'both'."
        )

    # resolve input folder path
    if args.input_folder is None:
        raise ValueError(
            "Missing `input_folder`. "
            + f"Usage: python {script_name} input_folder=/path/to/folder"
        )
    in_path = Path(args.input_folder).expanduser().resolve()
    if not in_path.is_dir():
        raise ValueError(f"Folder {str(in_path)} does not exist")

    # output is input if not specified
    if args.output_folder is None:
        out_path = in_path
    else:
        out_path = Path(args.output_folder).expanduser().resolve()
        out_path.mkdir(parents=True, exist_ok=True)

    # find images in input folder
    im_paths = [
        p
        for p in in_path.iterdir()
        if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
    ]
    if not im_paths:
        raise ValueError(f"No images found in {str(in_path)}")
    print(f"\nFound {len(im_paths)} images in {str(in_path)}\n")

    # load models
    cg = ClimateGAN("models/climategan")

    # load SD pipeline if need be
    if args.painter != "climategan":
        cg._setup_stable_diffusion()

    # infer and write
    total = len(im_paths)
    for i, im_path in enumerate(im_paths, start=1):
        print(">>> Processing", f"{i}/{total}", im_path.name)
        # Read the pixels eagerly so that the file handle is released.
        with Image.open(im_path) as opened:
            image_array = np.array(opened)
        outs = cg.infer_single(
            image_array,
            args.painter,
            as_pil_image=True,
            concats=[
                "input",
                "masked_input",
                "climategan_flood",
                "stable_copy_flood",
            ],
        )
        for k, v in outs.items():
            name = f"{im_path.stem}---{k}{im_path.suffix}"
            im = Image.fromarray(uint8(v))
            im.save(out_path / name)
        print(">>> Done", f"{i}/{total}", im_path.name, end="\n\n")