File size: 19,363 Bytes
cd31093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae61fe0
cd31093
 
 
 
 
 
 
 
 
ae61fe0
 
 
 
cd31093
 
 
 
 
 
 
 
 
 
 
 
 
 
ae61fe0
 
 
cd31093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ae61fe0
 
 
 
 
 
 
 
 
 
 
 
cd31093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py # noqa: E501
# thank you @NimaBoscarino

import re
from pathlib import Path
from uuid import uuid4

import numpy as np
import torch
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
from skimage.color import rgba2rgb
from skimage.transform import resize

from climategan.trainer import Trainer


def concat_events(output_dict, events, i=None, axis=1):
    """
    Builds a single uint8 array by concatenating, along `axis`, the data stored
    in `output_dict` under each key of `events`. Keys listed in `events` but
    absent from `output_dict` are silently skipped; the order of `events` is
    preserved.

    Args:
        output_dict (dict[Union[list[np.array], np.array]]): A dictionary mapping
            events to their corresponding data :
            {k: [HxWxC]} (for i != None) or {k: BxHxWxC}.
        events (list[str]): output_dict's keys to concatenate.
        i (int, optional): If provided, only the `i`th item of each selected
            entry is concatenated. Defaults to None (use the entries as-is).
        axis (int, optional): Concatenation axis. Defaults to 1.

    Returns:
        np.array(np.uint8): the concatenated data, cast to uint8.
    """
    present = [key for key in events if key in output_dict]
    if i is None:
        arrays = [output_dict[key] for key in present]
    else:
        arrays = [output_dict[key][i] for key in present]
    return uint8(np.concatenate(arrays, axis=axis))


def clear(folder):
    """
    Deletes every regular file in `folder` whose stem contains the inference
    separator "---". Files without the separator, and sub-directories (even if
    their name contains the separator), are left untouched. The file type is
    not checked: any file with "---" in its stem is removed.

    Args:
        folder (Union[str, Path]): The folder to clear.
    """
    # Materialise the listing before deleting so that removing entries cannot
    # interfere with the directory iteration.
    for path in list(Path(folder).iterdir()):
        if path.is_file() and "---" in path.stem:
            path.unlink()


def uint8(array, rescale=False):
    """
    Casts an array to np.uint8, optionally mapping it to the 0-255 range first.

    With `rescale=False` only the dtype changes (values are truncated by the
    cast). With `rescale=True`:
      * data with negative values must lie in [-1, 1] and is first mapped
        to [0, 1];
      * data whose maximum is then <= 1 is multiplied by 255;
      * non-negative data with a maximum above 1 is cast unchanged.

    Args:
        array (np.array): array to convert
        rescale (bool, optional): whether to map the data to 0-255 before
            casting. Defaults to False.
    Raises:
        ValueError: if `rescale` is set and the array has negative values but
            is not contained in [-1, 1]
    Returns:
        np.array(np.uint8): converted array
    """
    if not rescale:
        return array.astype(np.uint8)

    lowest = array.min()
    if lowest < 0:
        highest = array.max()
        if not (lowest >= -1 and highest <= 1):
            raise ValueError(
                f"Data range mismatch for image: ({lowest}, {highest})"
            )
        # [-1, 1] -> [0, 1]
        array = (array + 1) / 2

    # [0, 1] -> [0, 255]
    if array.max() <= 1:
        array = array * 255

    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Produces a square `to x to` version of an image without distorting it:
    the image is first resized (aspect ratio preserved) so that its smallest
    side measures `to`, then the central `to x to` window is cropped out.

    Args:
        img (np.array): np.uint8 255 image, indexed as HxWxC
        to (int, optional): side of the square output. Defaults to 640.
    Returns:
        np.array: float image with values in [0, 1]
    """
    height, width = img.shape[:2]

    # the smallest side becomes `to`, the other one is scaled accordingly
    if height < width:
        new_shape = (to, int(to * width / height))
    else:
        new_shape = (int(to * height / width), to)

    resized = uint8(resize(img, new_shape, preserve_range=True, anti_aliasing=True))

    # central crop
    res_h, res_w = resized.shape[:2]
    y0 = (res_h - to) // 2
    x0 = (res_w - to) // 2
    cropped = resized[y0 : y0 + to, x0 : x0 + to, :]

    return cropped / 255.0


def to_m1_p1(img):
    """
    Maps an image from the [0, 1] range to the [-1, +1] range.

    Args:
        img (np.array): numpy array of an image with values in [0, 1]
    Raises:
        ValueError: If the image is not in [0, 1]
    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    in_unit_range = img.min() >= 0 and img.max() <= 1
    if not in_unit_range:
        raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")
    centered = img.astype(np.float32) - 0.5
    return centered * 2


# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    """
    Inference wrapper around a ClimateGAN `Trainer`, optionally combined with a
    stable diffusion in-painting pipeline used as an alternative flood painter.
    """

    # Checkpoint loaded by `_setup_stable_diffusion`.
    _SD_INPAINT_MODEL = "runwayml/stable-diffusion-inpainting"

    # Default `concats` of the inference methods. Kept as one immutable tuple so
    # the default cannot be mutated by one call and leak into the next ones.
    _DEFAULT_CONCATS = (
        "input",
        "masked_input",
        "climategan_flood",
        "stable_flood",
        "stable_copy_flood",
    )

    def __init__(self, model_path, dev_mode=False) -> None:
        """
        A wrapper for the ClimateGAN model that you can use to generate
        events from images or folders containing images.

        Args:
            model_path (Union[str, Path]): Where to load the Masker from
            dev_mode (bool, optional): If True, no model is loaded and
                `infer_single` returns random arrays. Defaults to False.
        """
        # inference only: globally disable autograd
        torch.set_grad_enabled(False)
        self.target_size = 640
        self._stable_diffusion_is_setup = False
        self.dev_mode = dev_mode
        if self.dev_mode:
            return
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
        self.trainer.G.half()

    def _setup_stable_diffusion(self):
        """
        Sets up the stable diffusion pipeline for in-painting (no-op in dev mode).
        Make sure you have accepted the license on the model's card
        https://huggingface.co/runwayml/stable-diffusion-inpainting

        Raises:
            Exception: whatever the pipeline loading raised, after printing a hint.
        """
        if self.dev_mode:
            return

        try:
            self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                self._SD_INPAINT_MODEL,
                revision="fp16",
                torch_dtype=torch.float16,
                safety_checker=None,
            ).to(self.trainer.device)
        except Exception:
            # The hint points at the card of the checkpoint actually being loaded.
            print(
                "\nCould not load stable diffusion model. "
                + "Please make sure you have accepted the license on the model's"
                + f" card https://huggingface.co/{self._SD_INPAINT_MODEL}\n"
            )
            raise
        self._stable_diffusion_is_setup = True

    @staticmethod
    def _read_image(path):
        """
        Reads the image stored at `path` into a numpy array and closes the
        underlying file handle.

        Args:
            path (Union[str, Path]): path of the image to read.

        Returns:
            np.array: the image data.
        """
        with Image.open(path) as im:
            return np.array(im)

    def _preprocess_image(self, img):
        """
        Prepares a raw image for the networks: converts it to 3 channels,
        resizes and center-crops it to `self.target_size`, and rescales it
        to [-1, 1].

        Args:
            img (np.array): HxW (grayscale), HxWx3 or HxWx4 (RGBA) image.

        Returns:
            np.array(np.float32): target_size x target_size x 3 image in [-1, 1]
        """
        # grayscale to rgb: replicate the single channel
        if img.ndim == 2:
            img = np.stack([img] * 3, axis=-1)

        # rgba to rgb
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # to self.target_size
        data = resize_and_crop(data, self.target_size)

        # resize_and_crop() produces [0, 1] images, rescale to [-1, 1]
        data = to_m1_p1(data)
        return data

    # Does all three inferences at the moment.
    def infer_single(
        self,
        orig_image,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=_DEFAULT_CONCATS,
    ):
        """
        Infers the image with the ClimateGAN model.
        Importantly (and unlike self.infer_preprocessed_batch), the image is
        pre-processed by self._preprocess_image before going through the networks.

        In dev mode nothing is read or inferred: random integer arrays of the
        expected shapes are returned instead.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            orig_image (Union[str, Path, np.array]): image to infer on. Can be a
                path to an image which will be read.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].

        Returns:
            dict: a dictionary containing the output images {k: HxWxC}. C is omitted
                for masks (HxW).
        """
        if self.dev_mode:
            size = self.target_size

            def noise(*shape):
                return np.random.randint(0, 255, shape)

            return {
                "input": noise(size, size, 3),
                "mask": noise(size, size),
                "masked_input": noise(size, size, 3),
                "climategan_flood": noise(size, size, 3),
                "stable_flood": noise(size, size, 3),
                "stable_copy_flood": noise(size, size, 3),
                "concat": noise(size, size * 5, 3),
                "smog": noise(size, size, 3),
                "wildfire": noise(size, size, 3),
            }

        # paths (str or Path) are read from disk, arrays are used as-is
        image_array = (
            self._read_image(orig_image)
            if isinstance(orig_image, (str, Path))
            else orig_image
        )
        image = self._preprocess_image(image_array)
        # add a batch dimension, infer, then strip the batch dimension again
        output_dict = self.infer_preprocessed_batch(
            image[None, ...], painter, prompt, concats
        )
        return {k: v[0] for k, v in output_dict.items()}

    def infer_preprocessed_batch(
        self,
        images,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        concats=_DEFAULT_CONCATS,
    ):
        """
        Infers ClimateGAN predictions on a batch of preprocessed images.
        It assumes that each image in the batch has been preprocessed with
        self._preprocess_image().

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").
        - "concat": The entries listed in `concats` concatenated along the width
            (only if `concats` is not empty).

        Args:
            images (np.array): A batch of input images BxHxWx3
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in the "concat" output. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].

        Raises:
            AssertionError: if `painter` is not one of the supported values.

        Returns:
            dict: a dictionary containing the output images
        """
        assert painter in [
            "both",
            "stable_diffusion",
            "climategan",
        ], f"Unknown painter: {painter}"

        # NOTE(review): "flood" is ignored exactly when the ClimateGAN painter is
        # the only one requested, yet outputs["flood"] is popped below for that
        # painter. This looks inverted (ignoring it for "stable_diffusion" would
        # make more sense) but depends on Trainer.infer_all's handling of
        # `ignore_event`, which is not visible here -- left unchanged, confirm.
        ignore_event = set()
        if painter == "climategan":
            ignore_event.add("flood")

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        outputs = self.trainer.infer_all(
            images,
            numpy=True,
            bin_value=0.5,
            half=True,
            ignore_event=ignore_event,
            return_masks=True,
        )

        outputs["input"] = uint8(images, True)
        # from Bx1xHxW to BxHxWx1
        outputs["masked_input"] = outputs["input"] * (
            outputs["mask"].squeeze(1)[..., None] == 0
        )

        if painter in {"both", "climategan"}:
            outputs["climategan_flood"] = outputs.pop("flood")
        else:
            del outputs["flood"]

        if painter != "climategan":
            if not self._stable_diffusion_is_setup:
                print("Setting up stable diffusion in-painting pipeline")
                self._setup_stable_diffusion()

            mask = outputs["mask"].squeeze(1)
            # BxHxWx3 numpy -> Bx3xHxW tensor on the trainer's device
            input_images = (
                torch.tensor(images).permute(0, 3, 1, 2).to(self.trainer.device)
            )
            input_mask = torch.tensor(mask[:, None, ...] > 0).to(self.trainer.device)
            # Generate at the resolution of the provided batch rather than at a
            # hard-coded size, so the result lines up with the mask and inputs.
            height, width = (int(s) for s in images.shape[1:3])
            floods = self.sdip_pipeline(
                prompt=[prompt] * images.shape[0],
                image=input_images,
                mask_image=input_mask,
                height=height,
                width=width,
                num_inference_steps=50,
            )

            bin_mask = mask[..., None] > 0
            flood = np.stack([np.array(i) for i in floods.images])
            # paste the original context back outside of the mask
            copy_flood = flood * bin_mask + outputs["input"] * (1 - bin_mask)
            outputs["stable_flood"] = flood
            outputs["stable_copy_flood"] = copy_flood

        if concats:
            outputs["concat"] = concat_events(outputs, concats, axis=2)

        # drop singleton channel dimensions (e.g. masks: Bx1xHxW -> BxHxW)
        return {k: v.squeeze(1) if v.shape[1] == 1 else v for k, v in outputs.items()}

    def infer_folder(
        self,
        folder_path,
        painter="both",
        prompt="An HD picture of a street with dirty water after a heavy flood",
        batch_size=4,
        concats=_DEFAULT_CONCATS,
        write=True,
        overwrite=False,
    ):
        """
        Infers the images in a folder with the ClimateGAN model, batching images for
        inference according to the batch_size.

        Images must end in .jpg, .jpeg or .png (not case-sensitive).
        Images must not contain the separator ("---") in their name.

        Images will be written to disk in the same folder as the input images, with
        a name that depends on its data, potentially the prompt and a random
        identifier in case multiple inferences are run in the folder.

        Output dict contains the following keys:
        - "input": The input image
        - "mask": The mask used to generate the flood (from ClimateGAN's Masker)
        - "masked_input": The input image with the mask applied
        - "climategan_flood": The flooded image generated by ClimateGAN's Painter
            on the masked input (only if "painter" is "climategan" or "both").
        - "stable_flood": The flooded image in-painted by the stable diffusion model
            from the mask and the input image (only if "painter" is "stable_diffusion"
            or "both").
        - "stable_copy_flood": The flooded image in-painted by the stable diffusion
            model with its original context pasted back in:
            y = m * flooded + (1-m) * input
            (only if "painter" is "stable_diffusion" or "both").

        Args:
            folder_path (Union[str, Path]): Where to read images from.
            painter (str, optional): Which painter to use: "climategan",
                "stable_diffusion" or "both". Defaults to "both".
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "An HD picture of a street with dirty water after a heavy flood".
            batch_size (int, optional): Size of inference batches. Defaults to 4.
            concats (Sequence[str], optional): Keys in `output` to concatenate
                together in a new `concat` image written. Defaults to:
                ["input", "masked_input", "climategan_flood", "stable_flood",
                "stable_copy_flood"].
            write (bool, optional): Whether or not to write the outputs to the input
                folder. Defaults to True.
            overwrite (Union[bool, str], optional): Whether to overwrite the images or
                not. If a string is provided, it will be included in the name.
                Defaults to False.

        Raises:
            AssertionError: if the folder does not exist, is not a directory or
                contains no eligible image.

        Returns:
            dict: a dictionary mapping each output key to the list of per-image
                outputs, in the same order as the images were listed.
        """
        folder_path = Path(folder_path).expanduser().resolve()
        assert folder_path.exists(), f"Folder {str(folder_path)} does not exist"
        assert folder_path.is_dir(), f"{str(folder_path)} is not a directory"
        im_paths = [
            p
            for p in folder_path.iterdir()
            if p.suffix.lower() in [".jpg", ".png", ".jpeg"] and "---" not in p.name
        ]
        assert im_paths, f"No images found in {str(folder_path)}"
        ims = [self._preprocess_image(self._read_image(p)) for p in im_paths]
        batches = [
            np.stack(ims[i : i + batch_size]) for i in range(0, len(ims), batch_size)
        ]
        inferences = [
            self.infer_preprocessed_batch(b, painter, prompt, concats) for b in batches
        ]

        # flatten the per-batch dicts into {key: [one entry per image]}
        outputs = {
            key: [item for batch_out in inferences for item in batch_out[key]]
            for key in inferences[0]
        }

        if write:
            self.write(outputs, im_paths, painter, overwrite, prompt)

        return outputs

    def write(
        self,
        outputs,
        im_paths,
        painter="both",
        overwrite=False,
        prompt="",
    ):
        """
        Writes the outputs of the inference to disk, in the input folder.

        Images will be named like:
        f"{original_stem}---{overwrite_prefix}{painter_prefix}_{event}{suffix}"
        where `painter_prefix` is f"_stable_{prompt}" for the "stable_flood"
        event (prompt stripped of non-alphanumeric characters and lower-cased),
        "climategan" for a raw "flood" event written with painter "climategan",
        and empty otherwise. `suffix` is the input image's suffix.

        Args:
            outputs (dict): The inference procedure's output dict, mapping each
                event to a sequence of images aligned with `im_paths`.
            im_paths (list[Path]): The list of input images paths.
            painter (str, optional): Which painter was used. Defaults to "both".
            overwrite (Union[bool, str], optional): Whether to overwrite the
                images or not.
                If a string is provided, it will be included in the name.
                If False, a random identifier will be added to the name.
                Defaults to False.
            prompt (str, optional): The prompt used to guide the diffusion. Defaults
                to "".
        """
        # keep only characters that are safe in a file name
        prompt = re.sub(r"[^0-9a-zA-Z]+", "", prompt).lower()

        overwrite_prefix = ""
        if not overwrite:
            overwrite_prefix = str(uuid4())[:8]
        elif isinstance(overwrite, str):
            overwrite_prefix = overwrite
        if overwrite_prefix:
            print("Writing events with prefix", overwrite_prefix)

        # for each image, for each event/data type
        for i, im_path in enumerate(im_paths):
            for event, ims in outputs.items():
                if painter == "climategan" and event == "flood":
                    painter_prefix = "climategan"
                elif (
                    painter in {"stable_diffusion", "both"} and event == "stable_flood"
                ):
                    painter_prefix = f"_stable_{prompt}"
                else:
                    painter_prefix = ""

                im = Image.fromarray(uint8(ims[i]))
                imstem = f"{im_path.stem}---{overwrite_prefix}{painter_prefix}_{event}"
                im.save(im_path.parent / (imstem + im_path.suffix))