# based on https://huggingface.co/spaces/NimaBoscarino/climategan/blob/main/inferences.py  # noqa: E501
# thank you @NimaBoscarino
import torch
from skimage.color import rgba2rgb
from skimage.transform import resize
import numpy as np

from climategan.trainer import Trainer


def uint8(array):
    """
    Convert an array to np.uint8 (only changes the dtype, no rescaling).

    Args:
        array (np.array): array to convert

    Returns:
        np.array(np.uint8): converted array
    """
    return array.astype(np.uint8)


def resize_and_crop(img, to=640):
    """
    Resize an image so that it keeps its aspect ratio and its smallest dimension
    is `to`, then crop the resized image in its center so that the output is
    `to x to`, without aspect-ratio distortion.

    Args:
        img (np.array): np.uint8 255 image
        to (int, optional): side length of the square output. Defaults to 640.

    Returns:
        np.array: [0, 1] float image
    """
    # resize keeping aspect ratio: smallest dim is `to`
    h, w = img.shape[:2]
    if h < w:
        size = (to, int(to * w / h))
    else:
        size = (int(to * h / w), to)

    r_img = resize(img, size, preserve_range=True, anti_aliasing=True)
    r_img = uint8(r_img)

    # crop in the center
    H, W = r_img.shape[:2]
    top = (H - to) // 2
    left = (W - to) // 2
    rc_img = r_img[top : top + to, left : left + to, :]

    return rc_img / 255.0

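# Illustrative example (added here, not in the original file): a 480x800
# uint8 image first gets its smaller side resized to 640 (-> 640x1066),
# then a centered 640x640 crop is taken and rescaled to [0, 1]:
#
#   resize_and_crop(np.zeros((480, 800, 3), dtype=np.uint8)).shape
#   # -> (640, 640, 3)

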
def to_m1_p1(img):
    """
    Rescale a [0, 1] image to [-1, +1].

    Args:
        img (np.array): float numpy array of an image in [0, 1]

    Raises:
        ValueError: if the image is not in [0, 1]

    Returns:
        np.array(np.float32): array in [-1, +1]
    """
    if img.min() >= 0 and img.max() <= 1:
        return (img.astype(np.float32) - 0.5) * 2
    raise ValueError(f"Data range mismatch for image: ({img.min()}, {img.max()})")

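# Illustrative mapping (added here, not in the original file):
#
#   to_m1_p1(np.array([0.0, 0.5, 1.0]))
#   # -> array([-1., 0., 1.], dtype=float32)
#
# A raw uint8 [0, 255] image would raise ValueError; divide by 255 first,
# as resize_and_crop does.

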
# No need to do any timing in this, since it's just for the HF Space
class ClimateGAN:
    def __init__(self, model_path) -> None:
        torch.set_grad_enabled(False)
        self.target_size = 640
        self.trainer = Trainer.resume_from_path(
            model_path,
            setup=True,
            inference=True,
            new_exp=None,
        )
    # Does all three inferences at the moment.
    def inference(self, orig_image):
        image = self._preprocess_image(orig_image)

        # Retrieve numpy events as a dict {event: array[BxHxWxC]}
        outputs = self.trainer.infer_all(
            image,
            numpy=True,
            bin_value=0.5,
        )

        return (
            outputs["flood"].squeeze(),
            outputs["wildfire"].squeeze(),
            outputs["smog"].squeeze(),
        )
    def _preprocess_image(self, img):
        # RGBA to RGB
        data = img if img.shape[-1] == 3 else uint8(rgba2rgb(img) * 255)

        # resize and center-crop to self.target_size
        data = resize_and_crop(data, self.target_size)

        # resize_and_crop() returns [0, 1] images; rescale to [-1, 1]
        data = to_m1_p1(data)

        return data
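

# Minimal usage sketch (added for illustration, not part of the original
# Space). "model/" and "street.png" are placeholder names, not paths from the
# original code; substitute the released ClimateGAN weights and your own image.
if __name__ == "__main__":
    from skimage.io import imread, imsave

    cg = ClimateGAN(model_path="model/")  # assumed checkpoint directory
    street = imread("street.png")  # RGB or RGBA, any resolution

    flood, wildfire, smog = cg.inference(street)

    # The outputs are HxWxC numpy arrays; convert dtype/range if needed
    # before saving.
    imsave("flood.png", flood)
    imsave("wildfire.png", wildfire)
    imsave("smog.png", smog)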