import math
from typing import Tuple

import cv2
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms, models
from torchvision.utils import make_grid


def log(msg, lvl='info'):
    if lvl == 'info':
        print(f"***********{msg}****************")
    if lvl == 'error':
        print(f"!!! Exception: {msg} !!!")


def lab_shift(x, invert=False):
    # Rescale a Lab tensor between the [0, 255] image range and the native
    # Lab range (L in [0, 100], a/b in [-128, 127]).
    x = x.float()
    if invert:
        x[:, 0, :, :] /= 2.55
        x[:, 1, :, :] -= 128
        x[:, 2, :, :] -= 128
    else:
        x[:, 0, :, :] *= 2.55
        x[:, 1, :, :] += 128
        x[:, 2, :, :] += 128
    return x


def calculate_psnr(img1, img2):
    # img1 and img2 have range [0, 255]
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


def calculate_fpsnr(fmse):
    return 10 * math.log10(255.0 / (fmse + 1e-8))


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1), bit=8):
    '''
    Converts a torch Tensor into an image Numpy array
    Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
    Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
    '''
    norm = float(2 ** bit) - 1
    tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
    tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0, 1]
    n_dim = tensor.dim()
    if n_dim == 4:
        n_img = len(tensor)
        img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 3:
        img_np = tensor.numpy()
        img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
    elif n_dim == 2:
        img_np = tensor.numpy()
    else:
        raise TypeError(
            'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
    if out_type == np.uint8:
        # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
        img_np = (img_np * norm).round()
    return img_np.astype(out_type)
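
# A minimal sanity check for tensor2img and calculate_psnr -- a hedged usage
# sketch, not part of the original module. It assumes a random CHW float
# tensor in [0, 1]; the `_demo_*` names are illustrative only.
def _demo_psnr_roundtrip():
    demo_tensor = torch.rand(3, 16, 16)   # CHW, RGB, values in [0, 1]
    demo_img = tensor2img(demo_tensor)    # HWC, BGR, uint8 in [0, 255]
    # PSNR of an image against itself is infinite (MSE == 0).
    assert calculate_psnr(demo_img, demo_img) == float('inf')
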
Got {image.shape}") # Convert from sRGB to Linear RGB lin_rgb = rgb_to_linear_rgb(image) xyz_im: torch.Tensor = rgb_to_xyz(lin_rgb) # normalize for D65 white point xyz_ref_white = torch.tensor( [0.95047, 1.0, 1.08883], device=xyz_im.device, dtype=xyz_im.dtype)[..., :, None, None] xyz_normalized = torch.div(xyz_im, xyz_ref_white) threshold = 0.008856 power = torch.pow(xyz_normalized.clamp(min=threshold), 1 / 3.0) scale = 7.787 * xyz_normalized + 4.0 / 29.0 xyz_int = torch.where(xyz_normalized > threshold, power, scale) x: torch.Tensor = xyz_int[..., 0, :, :] y: torch.Tensor = xyz_int[..., 1, :, :] z: torch.Tensor = xyz_int[..., 2, :, :] L: torch.Tensor = (116.0 * y) - 16.0 a: torch.Tensor = 500.0 * (x - y) _b: torch.Tensor = 200.0 * (y - z) out: torch.Tensor = torch.stack([L, a, _b], dim=-3) return out def lab_to_rgb(image: torch.Tensor, clip: bool = True) -> torch.Tensor: r"""Convert a Lab image to RGB. The L channel is assumed to be in the range of :math:`[0, 100]`. a and b channels are in the range of :math:`[-128, 127]`. Args: image: Lab image to be converted to RGB with shape :math:`(*, 3, H, W)`. clip: Whether to apply clipping to insure output RGB values in range :math:`[0, 1]`. Returns: Lab version of the image with shape :math:`(*, 3, H, W)`. The output RGB image are in the range of :math:`[0, 1]`. Example: >>> input = torch.rand(2, 3, 4, 5) >>> output = lab_to_rgb(input) # 2x3x4x5 """ if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") if len(image.shape) < 3 or image.shape[-3] != 3: raise ValueError( f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") L: torch.Tensor = image[..., 0, :, :] a: torch.Tensor = image[..., 1, :, :] _b: torch.Tensor = image[..., 2, :, :] fy = (L + 16.0) / 116.0 fx = (a / 500.0) + fy fz = fy - (_b / 200.0) # if color data out of range: Z < 0 fz = fz.clamp(min=0.0) fxyz = torch.stack([fx, fy, fz], dim=-3) # Convert from Lab to XYZ power = torch.pow(fxyz, 3.0) scale = (fxyz - 4.0 / 29.0) / 7.787 xyz = torch.where(fxyz > 0.2068966, power, scale) # For D65 white point xyz_ref_white = torch.tensor( [0.95047, 1.0, 1.08883], device=xyz.device, dtype=xyz.dtype)[..., :, None, None] xyz_im = xyz * xyz_ref_white rgbs_im: torch.Tensor = xyz_to_rgb(xyz_im) # https://github.com/richzhang/colorization-pytorch/blob/66a1cb2e5258f7c8f374f582acc8b1ef99c13c27/util/util.py#L107 # rgbs_im = torch.where(rgbs_im < 0, torch.zeros_like(rgbs_im), rgbs_im) # Convert from RGB Linear to sRGB rgb_im = linear_rgb_to_rgb(rgbs_im) # Clip to 0,1 https://www.w3.org/Graphics/Color/srgb if clip: rgb_im = torch.clamp(rgb_im, min=0.0, max=1.0) return rgb_im def rgb_to_xyz(image: torch.Tensor) -> torch.Tensor: r"""Convert a RGB image to XYZ. .. image:: _static/img/rgb_to_xyz.png Args: image: RGB Image to be converted to XYZ with shape :math:`(*, 3, H, W)`. Returns: XYZ version of the image with shape :math:`(*, 3, H, W)`. Example: >>> input = torch.rand(2, 3, 4, 5) >>> output = rgb_to_xyz(input) # 2x3x4x5 """ if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") if len(image.shape) < 3 or image.shape[-3] != 3: raise ValueError( f"Input size must have a shape of (*, 3, H, W). 
Got {image.shape}") r: torch.Tensor = image[..., 0, :, :] g: torch.Tensor = image[..., 1, :, :] b: torch.Tensor = image[..., 2, :, :] x: torch.Tensor = 0.412453 * r + 0.357580 * g + 0.180423 * b y: torch.Tensor = 0.212671 * r + 0.715160 * g + 0.072169 * b z: torch.Tensor = 0.019334 * r + 0.119193 * g + 0.950227 * b out: torch.Tensor = torch.stack([x, y, z], -3) return out def xyz_to_rgb(image: torch.Tensor) -> torch.Tensor: r"""Convert a XYZ image to RGB. Args: image: XYZ Image to be converted to RGB with shape :math:`(*, 3, H, W)`. Returns: RGB version of the image with shape :math:`(*, 3, H, W)`. Example: >>> input = torch.rand(2, 3, 4, 5) >>> output = xyz_to_rgb(input) # 2x3x4x5 """ if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") if len(image.shape) < 3 or image.shape[-3] != 3: raise ValueError( f"Input size must have a shape of (*, 3, H, W). Got {image.shape}") x: torch.Tensor = image[..., 0, :, :] y: torch.Tensor = image[..., 1, :, :] z: torch.Tensor = image[..., 2, :, :] r: torch.Tensor = 3.2404813432005266 * x + - \ 1.5371515162713185 * y + -0.4985363261688878 * z g: torch.Tensor = -0.9692549499965682 * x + \ 1.8759900014898907 * y + 0.0415559265582928 * z b: torch.Tensor = 0.0556466391351772 * x + - \ 0.2040413383665112 * y + 1.0573110696453443 * z out: torch.Tensor = torch.stack([r, g, b], dim=-3) return out def rgb_to_linear_rgb(image: torch.Tensor) -> torch.Tensor: r"""Convert an sRGB image to linear RGB. Used in colorspace conversions. .. image:: _static/img/rgb_to_linear_rgb.png Args: image: sRGB Image to be converted to linear RGB of shape :math:`(*,3,H,W)`. Returns: linear RGB version of the image with shape of :math:`(*,3,H,W)`. Example: >>> input = torch.rand(2, 3, 4, 5) >>> output = rgb_to_linear_rgb(input) # 2x3x4x5 """ if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}") if len(image.shape) < 3 or image.shape[-3] != 3: raise ValueError( f"Input size must have a shape of (*, 3, H, W).Got {image.shape}") lin_rgb: torch.Tensor = torch.where(image > 0.04045, torch.pow( ((image + 0.055) / 1.055), 2.4), image / 12.92) return lin_rgb def linear_rgb_to_rgb(image: torch.Tensor) -> torch.Tensor: r"""Convert a linear RGB image to sRGB. Used in colorspace conversions. Args: image: linear RGB Image to be converted to sRGB of shape :math:`(*,3,H,W)`. Returns: sRGB version of the image with shape of shape :math:`(*,3,H,W)`. Example: >>> input = torch.rand(2, 3, 4, 5) >>> output = linear_rgb_to_rgb(input) # 2x3x4x5 """ if not isinstance(image, torch.Tensor): raise TypeError(f"Input type is not a torch.Tensor. 
Got {type(image)}") if len(image.shape) < 3 or image.shape[-3] != 3: raise ValueError( f"Input size must have a shape of (*, 3, H, W).Got {image.shape}") threshold = 0.0031308 rgb: torch.Tensor = torch.where( image > threshold, 1.055 * torch.pow(image.clamp(min=threshold), 1 / 2.4) - 0.055, 12.92 * image ) return rgb def inference_img(model, img, device='cpu'): h, w, _ = img.shape # print(img.shape) if h % 8 != 0 or w % 8 != 0: img = cv2.copyMakeBorder(img, 8-h % 8, 0, 8-w % 8, 0, cv2.BORDER_REFLECT) # print(img.shape) tensor_img = torch.from_numpy(img).permute(2, 0, 1).to(device) input_t = tensor_img input_t = input_t/255.0 normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) input_t = normalize(input_t) input_t = input_t.unsqueeze(0).float() with torch.no_grad(): out = model(input_t) # print("out",out.shape) result = out[0][:, -h:, -w:].cpu().numpy() # print(result.shape) return result[0]