Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn.functional as F | |
import random | |
import kornia | |
from torchvision.transforms.functional import adjust_brightness, adjust_contrast | |
from climategan.tutils import normalize, retrieve_sky_mask | |
try: | |
from kornia.filters import filter2d | |
except ImportError: | |
from kornia.filters import filter2D as filter2d | |
def increase_sky_mask(mask, p_w=0, p_h=0): | |
""" | |
Increases sky mask in width and height by a given pourcentage | |
(Purpose: when applying Gaussian blur, there are no artifacts of blue sky behind) | |
Args: | |
sky_mask (torch.Tensor): Sky mask of shape (H,W) | |
p_w (float): Percentage of mask width by which to increase | |
the width of the sky region | |
p_h (float): Percentage of mask height by which to increase | |
the height of the sky region | |
Returns: | |
torch.Tensor: Sky mask increased given p_w and p_h | |
""" | |
if p_h <= 0 and p_w <= 0: | |
return mask | |
n_lines = int(p_h * mask.shape[-2]) | |
n_cols = int(p_w * mask.shape[-1]) | |
temp_mask = mask.clone().detach() | |
for i in range(1, n_cols): | |
temp_mask[:, :, :, i::] += mask[:, :, :, 0:-i] | |
temp_mask[:, :, :, 0:-i] += mask[:, :, :, i::] | |
new_mask = temp_mask.clone().detach() | |
for i in range(1, n_lines): | |
new_mask[:, :, i::, :] += temp_mask[:, :, 0:-i, :] | |
new_mask[:, :, 0:-i, :] += temp_mask[:, :, i::, :] | |
new_mask[new_mask >= 1] = 1 | |
return new_mask | |
def paste_filter(x, filter_, mask): | |
""" | |
Pastes a filter over an image given a mask | |
Where the mask is 1, the filter is copied as is. | |
Where the mask is 0, the current value is preserved. | |
Intermediate values will mix the two images together. | |
Args: | |
x (torch.Tensor): Input tensor, range must be [0, 255] | |
filer_ (torch.Tensor): Filter, range must be [0, 255] | |
mask (torch.Tensor): Mask, range must be [0, 1] | |
Returns: | |
torch.Tensor: New tensor with filter pasted on it | |
""" | |
assert len(x.shape) == len(filter_.shape) == len(mask.shape) | |
x = filter_ * mask + x * (1 - mask) | |
return x | |
def add_fire(x, seg_preds, fire_opts): | |
""" | |
Transforms input tensor given wildfires event | |
Args: | |
x (torch.Tensor): Input tensor | |
seg_preds (torch.Tensor): Semantic segmentation predictions for input tensor | |
filter_color (tuple): (r,g,b) tuple for the color of the sky | |
blur_radius (float): radius of the Gaussian blur that smooths | |
the transition between sky and foreground | |
Returns: | |
torch.Tensor: Wildfire version of input tensor | |
""" | |
wildfire_tens = normalize(x, 0, 255) | |
# Warm the image | |
wildfire_tens[:, 2, :, :] -= 20 | |
wildfire_tens[:, 1, :, :] -= 10 | |
wildfire_tens[:, 0, :, :] += 40 | |
wildfire_tens.clamp_(0, 255) | |
wildfire_tens = wildfire_tens.to(torch.uint8) | |
# Darken the picture and increase contrast | |
wildfire_tens = adjust_contrast(wildfire_tens, contrast_factor=1.5) | |
wildfire_tens = adjust_brightness(wildfire_tens, brightness_factor=0.73) | |
sky_mask = retrieve_sky_mask(seg_preds).unsqueeze(1) | |
if fire_opts.get("crop_bottom_sky_mask"): | |
i = 2 * sky_mask.shape[-2] // 3 | |
sky_mask[..., i:, :] = 0 | |
sky_mask = F.interpolate( | |
sky_mask.to(torch.float), | |
(wildfire_tens.shape[-2], wildfire_tens.shape[-1]), | |
) | |
sky_mask = increase_sky_mask(sky_mask, 0.18, 0.18) | |
kernel_size = (fire_opts.get("kernel_size", 301), fire_opts.get("kernel_size", 301)) | |
sigma = (fire_opts.get("kernel_sigma", 150.5), fire_opts.get("kernel_sigma", 150.5)) | |
border_type = "reflect" | |
kernel = torch.unsqueeze( | |
kornia.filters.kernels.get_gaussian_kernel2d(kernel_size, sigma), dim=0 | |
).to(x.device) | |
sky_mask = filter2d(sky_mask, kernel, border_type) | |
filter_ = torch.ones(wildfire_tens.shape, device=x.device) | |
filter_[:, 0, :, :] = 255 | |
filter_[:, 1, :, :] = random.randint(100, 150) | |
filter_[:, 2, :, :] = 0 | |
wildfire_tens = paste_tensor(wildfire_tens, filter_, sky_mask, 200) | |
wildfire_tens = adjust_brightness(wildfire_tens.to(torch.uint8), 0.8) | |
wildfire_tens = wildfire_tens.to(torch.float) | |
# dummy pixels to fool scaling and preserve range | |
wildfire_tens[:, :, 0, 0] = 255.0 | |
wildfire_tens[:, :, -1, -1] = 0.0 | |
return wildfire_tens | |
def paste_tensor(source, filter_, mask, transparency): | |
mask = transparency / 255.0 * mask | |
new = mask * filter_ + (1.0 - mask) * source | |
return new | |