import math
import numpy as np
from typing import Tuple
import torch
import torch.nn as nn
from torchvision.utils import make_grid
import cv2
from torchvision import transforms, models
from PIL import Image
import torchvision.transforms.functional as tf
# --------------------------------------------Metric tools-------------------------------------------- #
def lab_shift(x, invert=False):
    """Shift a batched Lab tensor to / from a non-negative 0..255-style range.

    Forward (``invert=False``): channel 0 (L) is multiplied by 2.55 and
    channels 1 and 2 (a, b) are offset by +128.
    Inverse (``invert=True``): channel 0 is divided by 2.55 and channels 1
    and 2 are offset by -128.

    Args:
        x: Tensor indexed as (B, C, H, W) with L, a, b in channels 0-2.
            Any dtype is accepted; the result is float32.
        invert: Apply the inverse mapping instead of the forward one.

    Returns:
        A new float32 tensor. ``x`` itself is never modified.
    """
    # Always copy: ``x.float()`` returns ``x`` itself when it is already
    # float32, so the in-place updates below would otherwise write into the
    # caller's tensor.
    x = x.to(torch.float32, copy=True)
    if invert:
        x[:, 0, :, :] /= 2.55
        x[:, 1, :, :] -= 128
        x[:, 2, :, :] -= 128
    else:
        x[:, 0, :, :] *= 2.55
        x[:, 1, :, :] += 128
        x[:, 2, :, :] += 128
    return x
def calculate_psnr(img1, img2):
    """Return the PSNR in dB between two images whose pixels span [0, 255].

    Both arrays are promoted to float64 before subtracting, so unsigned
    integer inputs cannot wrap around. Identical images give ``inf``.
    """
    error = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(error ** 2)
    if mse != 0:
        return 20 * math.log10(255.0 / math.sqrt(mse))
    return float('inf')
def calculate_fpsnr(fmse):
    """Return the foreground PSNR (fPSNR) in dB for a foreground MSE.

    Uses the standard PSNR definition ``10 * log10(MAX**2 / MSE)`` with
    ``MAX = 255``, i.e. the same scale as ``calculate_psnr``.

    Args:
        fmse: Mean squared error over the foreground region, measured on
            a [0, 255] pixel scale.

    Returns:
        fPSNR in dB. A tiny epsilon keeps the result finite for ``fmse == 0``.
    """
    # The peak value must be squared: PSNR compares peak signal *power*
    # (MAX**2) with the mean *squared* error.
    return 10 * math.log10(255.0 ** 2 / (fmse + 1e-8))
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1), bit=8):
    """Convert a torch Tensor into an image NumPy array.

    Args:
        tensor: 4D (B, 3/1, H, W), 3D (C, H, W) or 2D (H, W) tensor in RGB
            channel order. Values are clamped to ``min_max`` and then mapped
            to [0, 1].
        out_type: NumPy dtype of the result. For integer dtypes the image
            is scaled to [0, 2**bit - 1] and rounded. Other dtypes keep the
            [0, 1] range.
        min_max: ``(low, high)`` value range of the input.
        bit: Bit depth used for integer scaling (8 -> 255, 16 -> 65535).

    Returns:
        A 3D (H, W, C) array in BGR channel order, or a 2D (H, W) array.

    Raises:
        TypeError: If the squeezed tensor is not 2D, 3D or 4D.
    """
    norm = float(2 ** bit) - 1
    # Out-of-place clamp: squeeze()/float()/cpu() may all hand back the
    # caller's own storage, so clamp_() would silently modify the input.
    tensor = tensor.squeeze().float().cpu().clamp(*min_max)
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0, 1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if np.issubdtype(out_type, np.integer):
        # Important: unlike MATLAB, casting to an integer dtype truncates
        # instead of rounding, so round explicitly first.
        img_np = (img_np * norm).round()
    return img_np.astype(out_type)
def rgb_to_lab(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to Lab.

    .. image:: _static/img/rgb_to_lab.png

    The input RGB image is assumed to be in the range of :math:`[0, 1]`.
    Lab is computed using the D65 illuminant and Observer 2.

    Args:
        image: RGB image to convert, with shape :math:`(*, 3, H, W)`.

    Returns:
        Lab version of the image with shape :math:`(*, 3, H, W)`. The L
        channel is in the range 0..100; a and b are in the range -128..127.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_lab(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if image.dim() < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    # sRGB -> linear RGB -> XYZ.
    xyz = rgb_to_xyz(rgb_to_linear_rgb(image))

    # Normalize by the D65 reference white.
    white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)
    xyz_n = xyz / white[..., :, None, None]

    # Cube root above the threshold, linear segment below it. The clamp keeps
    # negative bases (pow to 1/3 would give NaN) out of the branch that
    # torch.where discards.
    eps = 0.008856
    f_xyz = torch.where(
        xyz_n > eps,
        torch.pow(xyz_n.clamp(min=eps), 1 / 3.0),
        7.787 * xyz_n + 4.0 / 29.0,
    )
    fx, fy, fz = f_xyz.unbind(dim=-3)

    lightness = (116.0 * fy) - 16.0
    a_star = 500.0 * (fx - fy)
    b_star = 200.0 * (fy - fz)
    return torch.stack([lightness, a_star, b_star], dim=-3)
def lab_to_rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor:
    r"""Convert a Lab image to RGB.

    Args:
        image: Lab image to convert, with shape :math:`(*, 3, H, W)`. The L
            channel is assumed to be in :math:`[0, 100]`; a and b in
            :math:`[-128, 127]`.
        clip: Whether to clamp the output RGB values to :math:`[0, 1]`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`, in
        :math:`[0, 1]` when ``clip`` is True.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = lab_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if image.dim() < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    lightness, a_star, b_star = image.unbind(dim=-3)
    fy = (lightness + 16.0) / 116.0
    fx = (a_star / 500.0) + fy
    # Out-of-range color data can give fz < 0 (Z < 0); clamp it to zero.
    fz = (fy - (b_star / 200.0)).clamp(min=0.0)
    f_xyz = torch.stack([fx, fy, fz], dim=-3)

    # Lab -> XYZ: cube above ~6/29, linear segment below it.
    cubed = torch.pow(f_xyz, 3.0)
    linear_seg = (f_xyz - 4.0 / 29.0) / 7.787
    xyz = torch.where(f_xyz > 0.2068966, cubed, linear_seg)

    # Undo the D65 white-point normalization, then XYZ -> linear RGB -> sRGB.
    white = torch.tensor([0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)
    srgb = linear_rgb_to_rgb(xyz_to_rgb(xyz * white[..., :, None, None]))

    if not clip:
        return srgb
    return torch.clamp(srgb, min=0.0, max=1.0)
def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an RGB image to XYZ.

    .. image:: _static/img/rgb_to_xyz.png

    The RGB input should already be linear (``rgb_to_lab`` passes it the
    output of ``rgb_to_linear_rgb``).

    Args:
        image: RGB image to convert, with shape :math:`(*, 3, H, W)`.

    Returns:
        XYZ version of the image with shape :math:`(*, 3, H, W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_xyz(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if not (image.dim() >= 3 and image.shape[-3] == 3):
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    red, green, blue = image.unbind(dim=-3)
    x_chan = 0.412453 * red + 0.357580 * green + 0.180423 * blue
    y_chan = 0.212671 * red + 0.715160 * green + 0.072169 * blue
    z_chan = 0.019334 * red + 0.119193 * green + 0.950227 * blue
    return torch.stack([x_chan, y_chan, z_chan], -3)
def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an XYZ image to (linear) RGB.

    Args:
        image: XYZ image to convert, with shape :math:`(*, 3, H, W)`.

    Returns:
        RGB version of the image with shape :math:`(*, 3, H, W)`. Values are
        not clamped and may fall outside :math:`[0, 1]`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = xyz_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if not (image.dim() >= 3 and image.shape[-3] == 3):
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")

    x_chan, y_chan, z_chan = image.unbind(dim=-3)
    red = 3.2404813432005266 * x_chan - 1.5371515162713185 * y_chan - 0.4985363261688878 * z_chan
    green = -0.9692549499965682 * x_chan + 1.8759900014898907 * y_chan + 0.0415559265582928 * z_chan
    blue = 0.0556466391351772 * x_chan - 0.2040413383665112 * y_chan + 1.0573110696453443 * z_chan
    return torch.stack([red, green, blue], dim=-3)
def rgb_to_linear_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert an sRGB image to linear RGB. Used in colorspace conversions.

    .. image:: _static/img/rgb_to_linear_rgb.png

    Values above 0.04045 follow ``((v + 0.055) / 1.055) ** 2.4``; the rest
    use the linear segment ``v / 12.92``.

    Args:
        image: sRGB image to convert, with shape :math:`(*,3,H,W)`.

    Returns:
        Linear RGB version of the image with shape :math:`(*,3,H,W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = rgb_to_linear_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if image.dim() < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W).Got {image.shape}")

    expanded = torch.pow((image + 0.055) / 1.055, 2.4)
    linear_seg = image / 12.92
    return torch.where(image > 0.04045, expanded, linear_seg)
def linear_rgb_to_rgb(image: torch.Tensor) -> torch.Tensor:
    r"""Convert a linear RGB image to sRGB. Used in colorspace conversions.

    Values above 0.0031308 follow ``1.055 * v ** (1 / 2.4) - 0.055``; the
    rest use the linear segment ``12.92 * v``.

    Args:
        image: Linear RGB image to convert, with shape :math:`(*,3,H,W)`.

    Returns:
        sRGB version of the image with shape :math:`(*,3,H,W)`.

    Example:
        >>> input = torch.rand(2, 3, 4, 5)
        >>> output = linear_rgb_to_rgb(input)  # 2x3x4x5
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
    if image.dim() < 3 or image.shape[-3] != 3:
        raise ValueError(
            f"Input size must have a shape of (*, 3, H, W).Got {image.shape}")

    cutoff = 0.0031308
    # The clamp keeps pow away from negative bases (NaN) in the branch that
    # torch.where discards.
    gamma_seg = 1.055 * torch.pow(image.clamp(min=cutoff), 1 / 2.4) - 0.055
    linear_seg = 12.92 * image
    return torch.where(image > cutoff, gamma_seg, linear_seg)
# --------------------------------------------Inference tools-------------------------------------------- #
def inference_img(model, img, device='cpu'):
    """Run ``model`` on an H x W x 3 image array and return one output channel.

    The image is padded on the top/left so both sides become multiples of 8.
    It is then scaled by 1/255, normalized with the ImageNet mean/std and
    fed to ``model`` as a batch of one. The prediction is cropped back to
    the bottom-right ``h x w`` region, where the un-padded pixels live.

    Args:
        model: Callable taking a (1, 3, H', W') float tensor. Its output is
            indexed as a batch of (C, H', W') maps — TODO confirm.
        img: H x W x 3 array, presumably uint8 in [0, 255] since it is
            divided by 255 — verify against callers.
        device: Device the input tensor is moved to.

    Returns:
        (h, w) NumPy array: channel 0 of the cropped prediction.
    """
    h, w, _ = img.shape
    # ``-n % 8`` is 0 for a side that is already a multiple of 8, so aligned
    # sides get no padding (``8 - n % 8`` would add a needless 8 pixels).
    pad_h, pad_w = -h % 8, -w % 8
    if pad_h or pad_w:
        # NOTE(review): reflect border mode assumed here — confirm it matches
        # the padding the model was designed for.
        img = cv2.copyMakeBorder(img, pad_h, 0, pad_w, 0, cv2.BORDER_REFLECT)
    tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device)
    input_t = tensor_img / 255.0
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    input_t = normalize(input_t).unsqueeze(0).float()
    with torch.no_grad():
        out = model(input_t)
    # Padding was added on the top/left, so the original pixels are the
    # last h rows and last w columns.
    result = out[0][:, -h:, -w:].cpu().numpy()
    return result[0]
def log(msg, lvl='info'):
    """Print a status message to stdout.

    Args:
        msg: Message to print.
        lvl: ``'info'`` prints ``msg`` as-is. ``'error'`` prints it wrapped
            as ``!!! Exception: <msg> !!!``. Any other level prints nothing.
    """
    if lvl == 'info':
        print(msg)
    elif lvl == 'error':
        print(f"!!! Exception: {msg} !!!")
def harmonize(comp, mask, model):
    """Harmonize a composite image with ``model`` and return a uint8 array.

    Args:
        comp: Composite image exposing PIL's ``convert``/``size`` API
            (converted to RGB).
        mask: Foreground mask exposing PIL's ``convert`` (converted to 1-bit).
        model: Object with ``image_size`` (square working resolution) and
            ``harmonize(comp_tensor, mask_tensor)``. The output is indexed as
            a batch below — presumably (B, 3, H, W) in [0, 1]; TODO confirm.

    Returns:
        (H, W, 3) uint8 array at the input resolution, or a 16 x 16 x 3
        black uint8 placeholder when ``comp`` or ``mask`` is None.
    """
    log("Inference started")
    if comp is None or mask is None:
        log("Empty source")
        return np.zeros((16, 16, 3), dtype=np.uint8)
    comp = comp.convert('RGB')
    mask = mask.convert('1')
    in_shape = comp.size[::-1]  # PIL size is (W, H); resize wants (H, W)
    work_size = [model.image_size, model.image_size]
    compt = tf.to_tensor(tf.resize(comp, work_size))
    maskt = tf.to_tensor(tf.resize(mask, work_size))
    res = model.harmonize(compt, maskt)
    res = tf.resize(res, in_shape)
    log("Inference finished")
    # Clamp and round before the uint8 cast: a bare cast truncates and wraps
    # out-of-range values. detach()/cpu() let .numpy() work for GPU/grad tensors.
    img = (res.detach().cpu().clamp(0, 1) * 255).round()[0]
    return img.permute(1, 2, 0).numpy().astype(np.uint8)
def extract_matte(img, back, model):
    """Extract a foreground with ``model`` and composite it onto ``back``.

    Args:
        img: Input passed unchanged to ``model.extract``.
        back: Background image, resized to the matte's size below.
        model: Object whose ``extract(img)`` returns ``(mask, fg)``. ``mask``
            is indexed as (H, W) and ``fg`` is added to an (H, W, C)
            background — shapes and value ranges assumed; TODO confirm.

    Returns:
        ``[composite_pil, mask, fg_pil]``.
    """
    mask, fg = model.extract(img)
    fg_pil = Image.fromarray(np.uint8(fg))
    # NOTE(review): assumes ``back`` is a PIL Image, whose ``resize`` takes
    # (width, height) — hence the reversed (H, W) mask shape. Confirm with callers.
    background = np.array(back.resize(mask.shape[::-1]))
    composite = fg + (1 - mask[:, :, None]) * background
    # Clip and round before the uint8 cast so out-of-range sums saturate
    # instead of wrapping around.
    composite_pil = Image.fromarray(np.uint8(np.clip(np.round(composite), 0, 255)))
    return [composite_pil, mask, fg_pil]
def css(height=3, scale=2):
    """Return a CSS rule that sizes ``.output_image`` elements.

    Args:
        height: Height in ``rem``.
        scale: Width in ``rem``.
    """
    rule = f".output_image {{height: {height}rem !important; width: {scale}rem !important;}}"
    return rule