# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino

import os
import re
from pathlib import Path
from uuid import uuid4

import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize

from climategan.trainer import Trainer

# Output keys concatenated by default into the "concat" image.
# Kept as an immutable tuple so that it can safely be shared as a default
# argument value by several methods (a list default would be shared and
# mutable across calls).
DEFAULT_CONCATS = (
    "input",
    "masked_input",
    "climategan_flood",
    "stable_flood",
    "stable_copy_flood",
)


def _load_image(path):
    """
    Reads an image from disk into a numpy array and closes the file.

    Args:
        path (Union[str, Path]): Path of the image to read.

    Returns:
        np.array: The decoded image data.
    """
    # np.array() forces the pixel data to be loaded while the file is open
    with Image.open(path) as pil_image:
        return np.array(pil_image)


def concat_events(output_dict, events, i=None, axis=1):
    """
    Concatenates the `i`th data in `output_dict` according to the keys listed
    in `events` on dimension `axis`.

    Keys of `events` which are not in `output_dict` are silently skipped.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data :
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): Index of the item to select in each entry before
            concatenating. If None, whole entries are concatenated.
            Defaults to None.
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): The concatenated data, cast to np.uint8.
    """
    cs = [e for e in events if e in output_dict]
    if i is not None:
        return uint8(np.concatenate([output_dict[c][i] for c in cs], axis=axis))
    return uint8(np.concatenate([output_dict[c] for c in cs], axis=axis))


def clear(folder):
    """
    Deletes all the files of a folder which have the inference
    separator "---" in their stem, i.e. the files produced by a previous
    inference. Files without the separator (the inputs) are left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # materialize the listing first so that we do not delete files while the
    # directory iterator is still being consumed
    for i in list(Path(folder).iterdir()):
        if i.is_file() and "---" in i.stem:
            i.unlink()


def uint8(array, rescale=False):
    """
    Converts an array to np.uint8.

    By default only the dtype is changed. With `rescale=True` the data is
    first mapped to [0, 255]:
        * data with negative values must be in [-1, 1] and is mapped to [0, 1]
        * data whose max is <= 1 is then multiplied by 255

    Args:
        array (np.array): array to modify
        rescale (bool, optional): Whether to rescale the data to [0, 255]
            before casting. Defaults to False.

    Raises:
        ValueError: If `rescale` is True and the array has negative values
            but is not in [-1, 1]

    Returns:
        np.array(np.uint8): converted array
    """
    if rescale:
        if array.min() < 0:
            if array.min() >= -1 and array.max() <= 1:
                array = (array + 1) / 2
            else:
                raise ValueError(
                    f"Data range mismatch for image: ({array.min()}, {array.max()})"
                )
        if array.max() <= 1:
            array = array * 255
    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Resizes an image so that it keeps the aspect ratio and the smallest dimensions
    is `to`, then crops this resized image in its center so that the output is
    `to x to` without aspect ratio distortion

    Args:
        img (np.array): np.uint8 255 image
        to (int, optional): Side of the square output. Defaults to 640.

    Returns:
        np.array: [0, 1] float image
    """
    # resize keeping aspect ratio: smallest dim is `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]

    top = (H - to) // 2
    left = (W - to) // 2

    rc_img = r_img[top : top + to, left : left + to, :]

    return rc_img / 255.0


def to_m1_p1(img):
    """
    rescales a [0, 1] image to [-1, +1]

    Args:
        img (np.array): numpy array of an image in [0, 1]

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")


# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN `Trainer`, optionally combined with
    a stable diffusion in-painting pipeline which is set up lazily.
    """

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random images. Defaults to False.
        """
        # inference only: globally disable autograd
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting.
        Make sure you have accepted the license on the model's card
        https://huggingface.co/CompVis/stable-diffusion-v1-4

        Does nothing in dev mode. Any exception raised while loading the
        pipeline is re-raised after printing a hint.
        """
        if self.dev_mode:
            return

        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16",
                torch_dtype=torch.float16,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
            self._stable_diffusion_is_setup = True
        except Exception:
            # NOTE(review): the hint points to the CompVis v1-4 card while the
            # weights loaded above are "runwayml/stable-diffusion-inpainting"
            # -- confirm which card users actually need to accept.
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + " card https://huggingface.co/CompVis/stable-diffusion-v1-4\n"
            )
            # bare raise keeps the original traceback untouched
            raise

    def _preprocess_image(self, img):
        """
        Pre-processes an image for inference: drops the alpha channel if any,
        resizes and center-crops to `self.target_size` and rescales to [-1, 1].

        Args:
            img (np.array): HxWx3 or HxWx4 image.

        Returns:
            np.array(np.float32): target_size x target_size x 3 image in [-1, 1]
        """
        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # to args.target_size
        data = resize_and_crop(data, self.target_size)

        # resize() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data

    # Does all three inferences at the moment.
    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=DEFAULT_CONCATS,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is
            "stable_diffusion" or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        In dev mode, random images are returned for all of the keys above plus
        "concat", "smog" and "wildfire".

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be
                a path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after
                a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new `{original_stem}_concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            s = self.target_size
            return {
                "input": np.random.randint(0, 255, (s, s, 3)),
                "mask": np.random.randint(0, 255, (s, s)),
                "masked_input": np.random.randint(0, 255, (s, s, 3)),
                "climategan_flood": np.random.randint(0, 255, (s, s, 3)),
                "stable_flood": np.random.randint(0, 255, (s, s, 3)),
                "stable_copy_flood": np.random.randint(0, 255, (s, s, 3)),
                "concat": np.random.randint(0, 255, (s, s * 5, 3)),
                "smog": np.random.randint(0, 255, (s, s, 3)),
                "wildfire": np.random.randint(0, 255, (s, s, 3)),
            }

        # paths (str or Path) are read from disk, arrays are used as-is
        if isinstance(orig_image, (str, Path)):
            image_array = _load_image(orig_image)
        else:
            image_array = orig_image

        image = self._preprocess_image(image_array)
        # add a batch dimension for the batched inference
        output_dict = self.infer_preprocessed_batch(
            image[None, ...], painter, prompt, concats
        )
        # remove the batch dimension
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=DEFAULT_CONCATS,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is
            "stable_diffusion" or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The concatenation of the `concats` entries (only if
            `concats` is not empty).

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after
                a heavy flood".
            concats (list, optional): List of keys in `output` to concatenate together
                in a new `{original_stem}_concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        # NOTE(review): "flood" is ignored exactly when the ClimateGAN painter
        # is the one requested, yet `outputs.pop("flood")` below expects it in
        # that case -- confirm against Trainer.infer_all whether this
        # condition should rather be `painter == "stable_diffusion"`.
        ignore_event = set()
        if painter == "climategan":
            ignore_event.add("flood")

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=True,
            ignore_event=ignore_event,
            return_masks=True,
        )

        outputs["input"] = uint8(images, True)
        # from Bx1xHxW to BxHxWx1
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # tolerate an already-absent "flood" entry
            outputs.pop("flood", None)

        if painter != "climategan":
            # the diffusion pipeline is only loaded the first time it is needed
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            # BxHxWxC numpy -> BxCxHxW tensor
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
            )
            input_mask = torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=self.target_size,
                width=self.target_size,
                num_inference_steps=50,
            )

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # paste the original context back outside of the mask
            copy_flood = flood * bin_mask + uint8(images, True) * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # drop singleton channel dimensions (Bx1xHxW -> BxHxW)
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        batch_size=4,
        concats=DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.
        Images are processed in sorted path order.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on its data, potentially the prompt and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is
            "stable_diffusion" or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after
                a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (list, optional): List of keys in `output` to concatenate together
                in a new `{original_stem}_concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Raises:
            AssertionError: If the folder does not exist, is not a directory
                or contains no eligible image.

        Returns:
            dict: a dictionary containing the output images
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        # sorted: iterdir() order is arbitrary, sorting makes batches (and
        # therefore outputs) reproducible from one run to the next
        im_paths = sorted(
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
        )
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = [self._preprocess_image(_load_image(p)) for p in im_paths]
        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # flatten the per-batch outputs: {k: [image_0, image_1, ...]}
        outputs = {
            k: [i for e in inferences for i in e[k]] for k in inferences[0].keys()
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}{painter_prefix}_{output_type}{suffix}"
        `painter_prefix` is "climategan" for the "flood" output of the
        "climategan" painter, f"_stable_{prompt}" for the "stable_flood" output
        and empty otherwise. `prompt` is lower-cased and stripped from its
        non-alphanumeric characters.

        Args:
            outputs (dict): The inference procedure's output dict, indexable
                by image: outputs[output_type][i] matches im_paths[i].
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the
                images or not. If a string is provided, it will be included
                in the name. If False, a random identifier will be added to
                the name. Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "".
        """
        prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()
        overwrite_prefix = ""
        if not overwrite:
            # random identifier so that successive runs do not collide
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
            print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                painter_prefix = ""
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif (
                    painter in {"stable_diffusion", "both"} and event == "stable_flood"
                ):
                    painter_prefix = f"_stable_{prompt}"

                im = ims[i]
                im = Image.fromarray(uint8(im))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))