Spaces:
Runtime error
Runtime error
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501 | |
# thank you @NimaBoscarino | |
import os | |
import re | |
from pathlib import Path | |
from uuid import uuid4 | |
import numpy as np | |
import torch | |
from diffusers import StableDiffusionInpaintPipeline | |
from PIL import Image | |
from skimage.color import rgba2rgb | |
from skimage.transform import resize | |
from climategan.trainer import Trainer | |
def concat_events(output_dict, events, i=None, axis=1):
    """
    Concatenate the data stored under `events` in `output_dict` along `axis`.

    Keys listed in `events` that are missing from `output_dict` are skipped; the
    remaining ones are concatenated in the order they appear in `events`.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): Maps event names to
            their data: {k: [HxWxC]} (when `i` is given) or {k: BxHxWxC}.
        events (list[str]): Keys of `output_dict` to concatenate, in order.
        i (int, optional): If not None, only the `i`-th item of each entry is
            concatenated. Defaults to None (each entry is used whole).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array: The concatenated data, cast to np.uint8.
    """
    present = [event for event in events if event in output_dict]
    if i is None:
        parts = [output_dict[event] for event in present]
    else:
        parts = [output_dict[event][i] for event in present]
    return uint8(np.concatenate(parts, axis=axis))
def clear(folder):
    """
    Delete the inference outputs stored in `folder`.

    Inference outputs are recognised by the inference separator "---" in their
    file stem, which `ClimateGAN.write` puts in every file it saves. Every
    regular file whose stem contains the separator is removed, whatever its
    extension. Files without it (e.g. the original input images) and
    sub-directories are left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # Snapshot the listing first so that unlinking does not disturb iteration.
    for path in list(Path(folder).iterdir()):
        if path.is_file() and "---" in path.stem:
            path.unlink()
def uint8(array, rescale=False):
    """
    Convert an array to np.uint8, optionally rescaling it to [0, 255] first.

    Without `rescale`, only the dtype changes: values are truncated by
    `astype` and are expected to already lie in [0, 255].

    With `rescale=True`:
        - data with a negative minimum must lie in [-1, 1]; it is mapped to
          [0, 1] and then multiplied by 255;
        - otherwise data whose maximum is <= 1 is multiplied by 255;
        - data whose maximum is above 1 is assumed to already be in [0, 255].

    Args:
        array (np.array): Array to convert.
        rescale (bool, optional): Whether to rescale [-1, 1] or [0, 1] data to
            [0, 255] before casting. Defaults to False.

    Raises:
        ValueError: If `rescale` is True and the data has negative values but
            does not lie within [-1, 1].

    Returns:
        np.array(np.uint8): The converted array.
    """
    # Empty arrays have no min/max; there is nothing to rescale anyway.
    if rescale and array.size:
        low, high = array.min(), array.max()
        if low < 0:
            if low < -1 or high > 1:
                raise ValueError(f"Data range mismatch for image: ({low}, {high})")
            array = (array + 1) / 2  # [-1, 1] -> [0, 1]
            array = array * 255
        elif high <= 1:
            array = array * 255
    return array.astype(np.uint8)
def resize_and_crop(img, to=640):
    """
    Resize an image so that its smallest side is `to` pixels (keeping the aspect
    ratio), then crop the central `to x to` square out of it, so the output has
    no aspect-ratio distortion.

    Args:
        img (np.array): np.uint8 image in [0, 255], HxW or HxWxC.
        to (int, optional): Side of the square output. Defaults to 640.

    Returns:
        np.array: `to x to` (x C) float image in [0, 1]. The division by 255
        yields np.float64.
    """
    # Scale so that the smallest dimension becomes exactly `to`.
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)
    # preserve_range keeps the [0, 255] values so that uint8() only casts.
    r_img = uint8(resize(img, size, preserve_range=True, anti_aliasing=True))

    # Crop the central `to x to` window. Only the two spatial axes are sliced
    # so that both HxW and HxWxC images are supported.
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2
    rc_img = r_img[top : top + to, left : left + to]
    return rc_img / 255.0
def to_m1_p1(img):
    """
    Rescale a [0, 1] image to [-1, +1].

    Args:
        img (np.array): Numpy array of an image with values in [0, 1].

    Raises:
        ValueError: If the image values are not in [0, 1].

    Returns:
        np.array(np.float32): Array in [-1, +1].
    """
    if img.min() >= 0 and img.max() <= 1:
        # Cast before the arithmetic so the result is float32 whatever the input.
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
# No need to do any timing in this, since it's just for the HF Space | |
class ClimateGAN:
    """
    Wrapper around a ClimateGAN checkpoint that generates floods from images or
    folders of images. It uses ClimateGAN's Masker together with ClimateGAN's
    Painter and/or a Stable Diffusion in-painting pipeline.
    """

    # Shared defaults of the inference methods. They are immutable (str / tuple)
    # so that a caller can never mutate a default and affect later calls.
    _DEFAULT_PROMPT = "An HD picture of a street with dirty water after a heavy flood"
    _DEFAULT_CONCATS = (
        "input",
        "masked_input",
        "climategan_flood",
        "stable_flood",
        "stable_copy_flood",
    )
    _PAINTERS = ("both", "stable_diffusion", "climategan")

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Args:
            model_path (Union[str, Path]): Where to load the trainer (Masker and
                Painter) from.
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random placeholder images. Defaults to
                False.
        """
        # Inference only: autograd is disabled globally for this process.
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        # The generator runs in half precision (see half=True in infer_all).
        self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Load the fp16 "runwayml/stable-diffusion-inpainting" pipeline on the
        trainer's device. Does nothing in dev mode.

        The token passed to the hub is read from the HF_AUTH_TOKEN environment
        variable. Make sure you have accepted the license on the model's card:
        https://huggingface.co/runwayml/stable-diffusion-inpainting

        Raises:
            Exception: Whatever error prevented the pipeline from loading is
                re-raised after printing a hint.
        """
        if self.dev_mode:
            return
        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                "runwayml/stable-diffusion-inpainting",
                revision="fp16",
                torch_dtype=torch.float16,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
        except Exception:
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + " card https://huggingface.co/runwayml/stable-diffusion-inpainting\n"
            )
            raise
        self._stable_diffusion_is_setup = True

    def _preprocess_image(self, img):
        """
        Turn a raw image array into the networks' input format.

        Args:
            img (np.array): HxW (grayscale), HxWx3 (RGB) or HxWx4 (RGBA) image,
                np.uint8 in [0, 255].

        Returns:
            np.array: `target_size x target_size x 3` np.float32 array in [-1, 1].
        """
        if img.ndim == 2:
            # Grayscale: replicate the single channel so the networks get RGB.
            img = np.repeat(img[..., None], 3, axis=-1)
        # RGBA -> RGB: rgba2rgb yields [0, 1] floats, hence the * 255.
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)
        # Resize + center crop to target_size; resize_and_crop returns [0, 1].
        data = resize_and_crop(data, self.target_size)
        # Rescale [0, 1] -> [-1, 1].
        return to_m1_p1(data)

    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        concats=_DEFAULT_CONCATS,
    ):
        """
        Infer a single image with the ClimateGAN model.

        Unlike `infer_preprocessed_batch`, the image is first pre-processed by
        `_preprocess_image` (grayscale/RGBA to RGB, resize + center crop to
        `target_size`, rescale to [-1, 1]). The output keys are the ones
        described in `infer_preprocessed_batch`.

        Args:
            orig_image (Union[str, Path, np.array]): Image to infer on, or a
                path to an image file which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after a
                heavy flood".
            concats (Sequence[str], optional): Output keys to concatenate side
                by side into the "concat" output. Defaults to ("input",
                "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood").

        Returns:
            dict: The output images {k: HxWxC}; C is omitted for masks (HxW).
            In dev mode the input is not read and random np.uint8 placeholder
            images are returned instead.
        """
        if self.dev_mode:
            size = self.target_size
            shapes = {
                "input": (size, size, 3),
                "mask": (size, size),
                "masked_input": (size, size, 3),
                "climategan_flood": (size, size, 3),
                "stable_flood": (size, size, 3),
                "stable_copy_flood": (size, size, 3),
                "concat": (size, size * 5, 3),
                "smog": (size, size, 3),
                "wildfire": (size, size, 3),
            }
            return {
                key: np.random.randint(0, 256, shape, dtype=np.uint8)
                for key, shape in shapes.items()
            }

        if isinstance(orig_image, (str, Path)):
            # Context manager so the file handle is released right away.
            with Image.open(orig_image) as pil_image:
                image_array = np.array(pil_image)
        else:
            image_array = orig_image

        image = self._preprocess_image(image_array)
        output_dict = self.infer_preprocessed_batch(
            image[None, ...], painter, prompt, concats
        )
        # Remove the batch dimension added above.
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        concats=_DEFAULT_CONCATS,
    ):
        """
        Infer ClimateGAN predictions on a batch of preprocessed images.

        Each image of the batch is assumed to have been preprocessed with
        `_preprocess_image()`, i.e. to lie in [-1, 1].

        Output dict contains the following keys:
            - "input": The input image, np.uint8 in [0, 255].
            - "mask": The mask used to generate the flood (from ClimateGAN's
              Masker).
            - "masked_input": The input image with the mask applied.
            - "climategan_flood": The flooded image generated by ClimateGAN's
              Painter on the masked input (only if "painter" is "climategan" or
              "both").
            - "stable_flood": The flooded image in-painted by the stable
              diffusion model from the mask and the input image (only if
              "painter" is "stable_diffusion" or "both").
            - "stable_copy_flood": The stable diffusion flood with its original
              context pasted back in: y = m * flooded + (1-m) * input (only if
              "painter" is "stable_diffusion" or "both").
            - "concat": The `concats` outputs side by side (only if `concats`
              is not empty).
            - Any other event returned by the trainer's `infer_all`.

        Args:
            images (np.array): A batch of input images BxHxWx3 in [-1, 1].
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after a
                heavy flood".
            concats (Sequence[str], optional): Output keys to concatenate side
                by side (along the width) into the "concat" output. Defaults to
                ("input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood").

        Returns:
            dict: The output images, batched along the first axis. Singleton
            axis-1 dimensions are squeezed (e.g. Bx1xHxW masks become BxHxW).
        """
        assert painter in self._PAINTERS, f"Unknown painter: {painter}"

        # Only skip ClimateGAN's own flood when its Painter is not requested.
        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}; masks are
        # Bx1xHxW.
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=True,
            ignore_event=ignore_event,
            return_masks=True,
        )

        # `images` are known to lie in [-1, 1]: map them to [0, 255]
        # explicitly. uint8(..., rescale=True) would guess the range and treat
        # bright images (no negative value) as [0, 1] data.
        outputs["input"] = ((images + 1) / 2 * 255).astype(np.uint8)
        mask = outputs["mask"].squeeze(1)  # Bx1xHxW -> BxHxW
        # Keep the input only where the mask is 0 (BxHxWx1 broadcast).
        outputs["masked_input"] = outputs["input"] * (mask[..., None] == 0)

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # "flood" was listed in ignore_event, so it may be absent.
            outputs.pop("flood", None)

        if painter != "climategan":
            (
                outputs["stable_flood"],
                outputs["stable_copy_flood"],
            ) = self._inpaint_with_stable_diffusion(
                images, mask, outputs["input"], prompt
            )

        if concats:
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def _inpaint_with_stable_diffusion(self, images, mask, input_uint8, prompt):
        """
        In-paint `images` with Stable Diffusion wherever `mask` is positive.

        The pipeline is lazily set up on first use.

        Args:
            images (np.array): BxHxWx3 batch in [-1, 1].
            mask (np.array): BxHxW mask; pixels > 0 are in-painted.
            input_uint8 (np.array): `images` as BxHxWx3 np.uint8 in [0, 255].
            prompt (str): The prompt used to guide the diffusion.

        Returns:
            tuple[np.array, np.array]: The stacked in-painted images, and the
            same images with `input_uint8` pasted back outside the mask.
        """
        if not self._stable_diffusion_is_setup:
            print("Setting up stable diffusion in-painting pipeline")
            self._setup_stable_diffusion()

        device = self.trainer.device
        height, width = images.shape[1:3]
        result = self.sdip_pipeline(
            prompt=[prompt] * images.shape[0],
            # BxHxWx3 -> Bx3xHxW
            image=torch.tensor(images).permute(0, 3, 1, 2).to(device),
            mask_image=torch.tensor(mask[:, None, ...] > 0).to(device),
            height=height,
            width=width,
            num_inference_steps=50,
        )
        flood = np.stack([np.array(im) for im in result.images])
        # y = m * flooded + (1 - m) * input, without leaving the input dtype.
        copy_flood = np.where(mask[..., None] > 0, flood, input_uint8)
        return flood, copy_flood

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        batch_size=4,
        concats=_DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infer the images of a folder with the ClimateGAN model, in batches of
        `batch_size` images.

        Images must be regular files ending in .jpg, .jpeg or .png (not
        case-sensitive) and must not contain the separator ("---") in their
        name. They are processed in sorted path order. When `write` is True,
        outputs are written next to the input images (see `write` for the
        naming scheme). The output keys are the ones described in
        `infer_preprocessed_batch`.

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "An HD picture of a street with dirty water after a
                heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (Sequence[str], optional): Output keys to concatenate side
                by side into the "concat" output (written as
                `{original_stem}---..._concat` when `write` is True). Defaults
                to ("input", "masked_input", "climategan_flood",
                "stable_flood", "stable_copy_flood").
            write (bool, optional): Whether or not to write the outputs to the
                input folder. Defaults to True.
            overwrite (Union[bool, str], optional): Passed to `write`. Whether
                to overwrite the images or not. If a string is provided, it will
                be included in the name. Defaults to False.

        Returns:
            dict: {k: list of per-image arrays}, aligned with the sorted image
            paths.
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        im_paths = sorted(
            p
            for p in folder_path.iterdir()
            if p.is_file()
            and p.suffix.lower() in {".jpg", ".png", ".jpeg"}
            and "---" not in p.name
        )
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = []
        for p in im_paths:
            with Image.open(p) as pil_image:
                ims.append(self._preprocess_image(np.array(pil_image)))

        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]
        # Flatten the per-batch arrays into one list of per-image arrays per key.
        outputs = {
            k: [im for inference in inferences for im in inference[k]]
            for k in inferences[0]
        }
        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)
        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Write the outputs of the inference to disk, next to the input images.

        Each output image is saved as:
            f"{original_stem}---{overwrite_prefix}{painter_prefix}_{event}{suffix}"
        where `event` is the key in `outputs` and `suffix` is the input image's
        suffix (including the dot). `painter_prefix` is:
            - f"_stable_{prompt}" for the "stable_flood" event when `painter` is
              "stable_diffusion" or "both" (prompt reduced to lowercase
              alphanumerics);
            - "climategan" for a "flood" event when `painter` is "climategan";
            - "" otherwise.

        Args:
            outputs (dict[str, list[np.array]]): The inference procedure's
                output dict: for each event, one image per entry of `im_paths`.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Selects `overwrite_prefix`:
                a falsy value gives a random 8-character identifier (so runs do
                not overwrite each other), a non-empty string is used as-is, and
                any other truthy value gives "" (previous outputs with the same
                name are overwritten). Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion.
                Defaults to "".
        """
        # Keep only alphanumerics so the prompt is safe to embed in a file name.
        prompt = re.sub("[^0-9a-zA-Z]+", "", prompt).lower()

        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
            print("Writing events with prefix", overwrite_prefix)
        else:
            overwrite_prefix = ""

        stable_prefix = f"_stable_{prompt}"
        # For each image, for each event/data type.
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif painter in {"stable_diffusion", "both"} and event == "stable_flood":
                    painter_prefix = stable_prefix
                else:
                    painter_prefix = ""
                # NOTE(review): uint8() only casts; a mask stored as 0/1 values
                # would be saved nearly black. Confirm the mask range returned
                # by trainer.infer_all.
                im = Image.fromarray(uint8(ims[i]))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))