climateGAN / climategan_wrapper.py
vict0rsch's picture
add `HF_AUTH_TOKEN`
b7cb85c
raw
history blame
19.4 kB
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino
import os
import re
from pathlib import Path
from uuid import uuid4
import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize
from climategan.trainer import Trainer
def concat_events(output_dict, events, i=None, axis=1):
    """
    Joins the data stored under several keys of `output_dict` into one array
    and casts the result with `uint8()`.

    Keys listed in `events` that are missing from `output_dict` are skipped;
    the remaining ones are concatenated in the order given by `events`.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data :
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): If not None, only the `i`th item of each entry is
            concatenated. Defaults to None (whole entries are concatenated).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): the concatenated data
    """
    present = [key for key in events if key in output_dict]
    if i is None:
        parts = [output_dict[key] for key in present]
    else:
        parts = [output_dict[key][i] for key in present]
    return uint8(np.concatenate(parts, axis=axis))
def clear(folder):
    """
    Removes previously written inference results from `folder`: every regular
    file whose stem contains the inference separator "---" is deleted.
    Sub-directories and files without the separator are left untouched.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # snapshot the listing first so that deleting does not disturb iteration
    entries = list(Path(folder).iterdir())
    for entry in entries:
        if not entry.is_file():
            continue
        if "---" in entry.stem:
            entry.unlink()
def uint8(array, rescale=False):
    """
    Casts an array to np.uint8 with `astype`.

    By default only the dtype changes. With `rescale=True` the values are first
    brought to the 0-255 range:
      * data with negative values must lie in [-1, 1] and is mapped to [0, 1]
      * data whose maximum is <= 1 is then multiplied by 255

    Args:
        array (np.array): array to modify
        rescale (bool, optional): Whether to rescale the values before casting.
            Defaults to False.

    Raises:
        ValueError: If `rescale` is True and the array has negative values
            but is not within [-1, 1]

    Returns:
        np.array(np.uint8): converted array
    """
    if not rescale:
        return array.astype(np.uint8)

    if array.min() < 0:
        within_m1_p1 = array.min() >= -1 and array.max() <= 1
        if not within_m1_p1:
            raise ValueError(
                f"Data range mismatch for image: ({array.min()}, {array.max()})"
            )
        # [-1, 1] -> [0, 1]
        array = (array + 1) / 2

    if array.max() <= 1:
        # [0, 1] -> [0, 255]
        array = array * 255

    return array.astype(np.uint8)
def resize_and_crop(img, to=640):
    """
    Produces a square `to x to` view of an image without distorting it:
    the image is first resized, keeping its aspect ratio, so that its smallest
    side measures `to`, then the central `to x to` square is cropped out.

    Args:
        img (np.array): np.uint8 255 image (HxWxC)
        to (int, optional): Side of the output square. Defaults to 640.

    Returns:
        np.array: float image in [0, 1] (the uint8 crop divided by 255.0)
    """
    height, width = img.shape[:2]

    # the shorter side becomes exactly `to`, the other one scales accordingly
    if height < width:
        new_shape = (to, int(to * width / height))
    else:
        new_shape = (int(to * height / width), to)

    resized = uint8(resize(img, new_shape, preserve_range=True, anti_aliasing=True))

    # central crop
    new_height, new_width = resized.shape[:2]
    row0 = (new_height - to) // 2
    col0 = (new_width - to) // 2
    cropped = resized[row0 : row0 + to, col0 : col0 + to, :]

    return cropped / 255.0
def to_m1_p1(img):
    """
    Maps an image with values in [0, 1] to the [-1, +1] range.

    Args:
        img (np.array): numpy array of an image in [0, 1]

    Raises:
        ValueError: If the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    in_unit_range = img.min() >= 0 and img.max() <= 1
    if not in_unit_range:
        raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
    centered = img.astype(np.float32) - 0.5
    return centered * 2
# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN `Trainer`, optionally combined with a
    stable diffusion in-painting pipeline used as an alternative flood painter.

    Images can be inferred one at a time (`infer_single`), as pre-processed
    batches (`infer_preprocessed_batch`) or from a folder (`infer_folder`).
    """

    # Hugging Face Hub id of the in-painting checkpoint (loaded lazily)
    _SD_INPAINTING_MODEL = "runwayml/stable-diffusion-inpainting"

    # Defaults shared by the inference methods. The tuple is immutable so it is
    # safe to use as a default argument value.
    _DEFAULT_PROMPT = "An HD picture of a street with dirty water after a heavy flood"
    _DEFAULT_CONCATS = (
        "input",
        "masked_input",
        "climategan_flood",
        "stable_flood",
        "stable_copy_flood",
    )

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Gradient computation is disabled globally (torch.set_grad_enabled(False))
        and the generator is converted to half precision.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random arrays instead of real
                predictions. Defaults to False.
        """
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            # nothing to load: self.trainer is intentionally not created
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting (fp16 weights,
        no safety checker) and moves it to the trainer's device.

        The Hugging Face token is read from the `HF_AUTH_TOKEN` environment
        variable. Make sure you have accepted the license on the model's card.

        Does nothing in dev mode.

        Raises:
            Exception: Whatever loading the pipeline raised, re-raised after a
                hint has been printed.
        """
        if self.dev_mode:
            return

        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                self._SD_INPAINTING_MODEL,
                revision="fp16",
                torch_dtype=torch.float16,
                safety_checker=None,
                use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
            ).to(self.trainer.device)
            self._stable_diffusion_is_setup = True
        except Exception:
            # point to the card of the model that is actually being loaded
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + f" card https://huggingface.co/{self._SD_INPAINTING_MODEL}\n"
            )
            # bare raise: keep the original traceback untouched
            raise

    @staticmethod
    def _read_image(path):
        """
        Reads an image file into a numpy array, closing the file afterwards.

        Args:
            path (Union[str, Path]): Path to the image.

        Returns:
            np.array: The decoded image data.
        """
        with Image.open(path) as pil_image:
            # np.array() forces the pixel data to be loaded before closing
            return np.array(pil_image)

    def _preprocess_image(self, img):
        """
        Prepares a raw image for the networks: RGBA -> RGB if the last
        dimension is not 3, resize + center-crop to `self.target_size`, then
        rescale to [-1, 1].

        Args:
            img (np.array): HxWxC image array.

        Returns:
            np.array: target_size x target_size x 3 float32 array in [-1, 1]
        """
        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # to self.target_size
        data = resize_and_crop(data, self.target_size)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data

    # Does all three inferences at the moment.
    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        concats=_DEFAULT_CONCATS,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        In dev mode, the image is not read and a dictionary of random arrays
        is returned instead.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The `concats` entries concatenated side by side (only if
            `concats` is not empty).

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be a
                path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" entry. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            size = self.target_size

            def random_image(*shape):
                return np.random.randint(0, 255, shape)

            return {
                "input": random_image(size, size, 3),
                "mask": random_image(size, size),
                "masked_input": random_image(size, size, 3),
                "climategan_flood": random_image(size, size, 3),
                "stable_flood": random_image(size, size, 3),
                "stable_copy_flood": random_image(size, size, 3),
                "concat": random_image(size, size * 5, 3),
                "smog": random_image(size, size, 3),
                "wildfire": random_image(size, size, 3),
            }

        if isinstance(orig_image, (str, Path)):
            image_array = self._read_image(orig_image)
        else:
            image_array = orig_image

        image = self._preprocess_image(image_array)
        # add a batch dimension, infer, then strip it again
        output_dict = self.infer_preprocessed_batch(
            image[None, ...], painter, prompt, concats
        )
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        concats=_DEFAULT_CONCATS,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys (plus any other event returned
        by the trainer's `infer_all`):
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The `concats` entries concatenated along the width (only if
            `concats` is not empty).

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" entry. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].

        Raises:
            AssertionError: If `painter` is not one of the supported values.

        Returns:
            dict: a dictionary containing the output images. Entries whose second
                dimension has size 1 are squeezed along that dimension.
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        # ClimateGAN's own flood is only useless when stable diffusion is the
        # sole painter: "climategan" and "both" need it as "climategan_flood".
        # NOTE(review): assumes Trainer.infer_all skips (or nulls) the events
        # listed in ignore_event -- confirm against the trainer.
        ignore_event = set()
        if painter == "stable_diffusion":
            ignore_event.add("flood")

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=True,
            ignore_event=ignore_event,
            return_masks=True,
        )

        outputs["input"] = uint8(images, True)
        # from Bx1xHxW to BxHxW
        mask = outputs["mask"].squeeze(1)
        # zero-out the pixels covered by the mask (BxHxWx1 broadcast)
        outputs["masked_input"] = outputs["input"] * (mask[..., None] == 0)

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            # tolerate the trainer not returning the ignored event at all
            outputs.pop("flood", None)

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            device = self.trainer.device
            # BxHxWx3 -> Bx3xHxW
            input_images = torch.tensor(images).permute(0, 3, 1, 2).to(device)
            # boolean Bx1xHxW mask of the area to in-paint
            input_mask = torch.tensor(mask[:, None, ...] > 0).to(device)

            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=self.target_size,
                width=self.target_size,
                num_inference_steps=50,
            )

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # keep the in-painting inside the mask, the original image outside
            copy_flood = flood * bin_mask + outputs["input"] * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            # axis 2 is the width of BxHxWxC arrays: images end up side by side
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt=_DEFAULT_PROMPT,
        batch_size=4,
        concats=_DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.
        Images are processed in sorted path order, so the i-th item of each
        output list corresponds to the i-th image of that ordering.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on its data, potentially the prompt and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in a "concat" image. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Raises:
            AssertionError: If the folder does not exist, is not a directory or
                contains no eligible image.

        Returns:
            dict: a dictionary {key: list of per-image outputs}
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"

        # iterdir() yields entries in arbitrary order: sort for reproducibility
        im_paths = sorted(
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in {".jpg", ".png", ".jpeg"} and "---" not in p.name
        )
        assert im_paths, f"No images found in {str(folder_path)}"

        ims = [self._preprocess_image(self._read_image(p)) for p in im_paths]
        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # flatten the per-batch arrays into one list of images per key
        outputs = {
            k: [im for batch_out in inferences for im in batch_out[k]]
            for k in inferences[0]
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}{painter_prefix}_{event}{suffix}"
        where `painter_prefix` is f"_stable_{prompt}" for the "stable_flood"
        event (prompt lower-cased and stripped of non-alphanumeric characters),
        "climategan" for a "flood" event written with the "climategan" painter,
        and empty otherwise. `suffix` is the input image's extension.

        Args:
            outputs (dict): The inference procedure's output dict
                {event: sequence of images}, indexed like `im_paths`.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the
                images or not.
                If a string is provided, it will be included in the name.
                If False, a random identifier will be added to the name.
                Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "".
        """
        # keep only characters that are safe in a file name
        prompt = re.sub(r"[^0-9a-zA-Z]+", "", prompt).lower()

        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
            print("Writing events with prefix", overwrite_prefix)
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
            print("Writing events with prefix", overwrite_prefix)

        # the part of the name after the separator only depends on the event:
        # compute it once instead of once per image
        name_tails = {}
        for event in outputs:
            if painter == "climategan" and event == "flood":
                painter_prefix = "climategan"
            elif painter in {"stable_diffusion", "both"} and event == "stable_flood":
                painter_prefix = f"_stable_{prompt}"
            else:
                painter_prefix = ""
            name_tails[event] = f"{overwrite_prefix}{painter_prefix}_{event}"

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                im = Image.fromarray(uint8(ims[i]))
                imstem = f"{im_path.stem}---{name_tails[event]}"
                im.save(im_path.parent / (imstem + im_path.suffix))