from torchvision import transforms
from torch.utils.data import Dataset
import torch.nn.functional as F
import torch.nn as nn
import torch
import cv2

from .normalizer import PatchNormalizer, PatchedHarmonizer

# Expected input resolution: (256, 512, 3)


def inpaint_bg(comp, mask, dim=(2, 3)):
    """Inpaint the background of a composite for image harmonization.

    The masked (foreground) region is filled with the per-channel mean of the
    background pixels.

    Args:
        comp (torch.Tensor): composite image, values in [0, 1].
        mask (torch.Tensor): foreground mask, values in [0, 1].
    """
    back = comp * (1 - mask)
    total = torch.sum(back, dim=dim)        # (B, C) sum over background pixels
    num = torch.sum(1 - mask, dim=dim)      # (B, C) number of background pixels
    mean = (total / num)[:, :, None, None]  # per-channel background mean
    return back + mask * mean


class UpsampleShuffle(nn.Sequential):
    """Upsample by 2x with a 1x1 convolution followed by pixel shuffle."""

    def __init__(self, in_channels, out_channels, activation=True):
        super().__init__(
            nn.Conv2d(in_channels, out_channels * 4, kernel_size=1),
            nn.GELU() if activation else nn.Identity(),
            nn.PixelShuffle(2),
        )

    def reset_parameters(self):
        init_subpixel(self[0].weight)
        nn.init.zeros_(self[0].bias)


class UpsampleResize(nn.Sequential):
    """Upsample by interpolation followed by a reflection-padded 3x3 convolution."""

    def __init__(self, in_channels, out_channels, out_size=None, activation=None,
                 scale_factor=2., mode='bilinear'):
        super().__init__(
            nn.Upsample(scale_factor=scale_factor, mode=mode) if out_size is None
            else nn.Upsample(out_size, mode=mode),
            nn.ReflectionPad2d(1),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=0),
            activation() if activation is not None else nn.Identity(),
        )


def conv_bn(in_, out_, kernel_size=3, stride=1, padding=1,
            activation=nn.ReLU, normalization=nn.InstanceNorm2d):
    """Convolution followed by an optional normalization and an activation."""
    return nn.Sequential(
        nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=padding),
        normalization(out_) if normalization is not None else nn.Identity(),
        activation(),
    )


def init_subpixel(weight):
    """Initialize a sub-pixel convolution so each shuffled group shares one kernel."""
    co, ci, h, w = weight.shape
    co2 = co // 4
    # initialize the sub-kernel
    k = torch.empty([co2, ci, h, w])
    nn.init.kaiming_uniform_(k)
    # repeat it 4 times along the output-channel dimension
    k = k.repeat_interleave(4, dim=0)
    weight.data.copy_(k)


class DownsampleShuffle(nn.Sequential):
    """Downsample by 2x with a 1x1 channel reduction followed by pixel unshuffle."""

    def __init__(self, in_channels):
        assert in_channels % 4 == 0
        super().__init__(
            nn.Conv2d(in_channels, in_channels // 4, kernel_size=1),
            nn.ReLU(),
            nn.PixelUnshuffle(2),
        )

    def reset_parameters(self):
        init_subpixel(self[0].weight)
        nn.init.zeros_(self[0].bias)


def conv_bn_elu(in_, out_, kernel_size=3, stride=1, padding=True):
    """Convolution with an ELU activation."""
    pad = kernel_size // 2 if padding else 0
    return nn.Sequential(
        nn.Conv2d(in_, out_, kernel_size, stride=stride, padding=pad),
        nn.ELU(),
    )


class Inference_Data(Dataset):
    """Single-image dataset that loads and resizes one image for inference."""

    def __init__(self, img_path):
        self.input_img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
        self.input_img = cv2.resize(
            self.input_img, (512, 256), interpolation=cv2.INTER_CUBIC)
        self.to_tensor = transforms.ToTensor()
        self.data_len = 1

    def __getitem__(self, index):
        return self.to_tensor(self.input_img)

    def __len__(self):
        return self.data_len


class MyAdaptiveMaxPool2d(nn.Module):
    """Global max pooling over the full spatial extent (the `sz` argument is unused)."""

    def __init__(self, sz=None):
        super().__init__()

    def forward(self, x):
        inp_size = x.size()
        # Pool over the full spatial extent, producing a (B, C, 1, 1) tensor.
        return F.max_pool2d(input=x, kernel_size=(inp_size[2], inp_size[3]))


class SEBlock(nn.Module):
    """Squeeze-and-excitation block with an optional auxiliary spatial gate."""

    def __init__(self, channel, reduction=8):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid())

    def forward(self, x, aux_inp=None):
        b, c, _, _ = x.size()
        # Channel attention: squeeze to (B, C), excite, and rescale the input.
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        r = x * y
        if aux_inp is not None:
            # Spatial gate derived from the auxiliary (patch-harmonized) input.
            aux_weights = MyAdaptiveMaxPool2d(aux_inp.shape[-1] // 8)(aux_inp)
            aux_weights = nn.Sigmoid()(aux_weights.mean(1, keepdim=True))
            r += x * aux_weights
        return r


class ConvTransposeUp(nn.Sequential):
    """Transposed convolution followed by optional normalization and activation."""

    def __init__(self, in_channels, out_channels, norm, kernel_size=3, stride=2,
                 padding=1, activation=None):
        super().__init__(
            nn.ConvTranspose2d(in_channels, out_channels,
                               kernel_size=kernel_size, padding=padding, stride=stride),
            norm(out_channels) if norm is not None else nn.Identity(),
            activation() if activation is not None else nn.Identity(),
        )


class SkipConnect(nn.Module):
    """Fuses a decoder feature with its skip connection via a 3x3 convolution."""

    def __init__(self, channel):
        super(SkipConnect, self).__init__()
        self.rconv = nn.Conv2d(channel * 2, channel, 3, padding=1, bias=False)

    def forward(self, feature):
        return F.relu(self.rconv(feature))


class AttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(AttentionBlock, self).__init__()
        self.attn = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels * 2, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.attn(x)


class PatchHarmonizerBlock(nn.Module):
    """Harmonizes the foreground against the background using patch statistics."""

    def __init__(self, in_channels=3, grid_count=5):
        super(PatchHarmonizerBlock, self).__init__()
        self.patch_harmonizer = PatchedHarmonizer(grid_count=grid_count)
        self.head = conv_bn(in_channels * 2, in_channels,
                            kernel_size=3, padding=1, normalization=None)

    def forward(self, fg, bg, mask):
        fg_harm, _ = self.patch_harmonizer(fg, bg, mask)
        return self.head(torch.cat([fg, fg_harm], 1))


class PHNet(nn.Module):
    """Patch-based harmonization network with an encoder-decoder backbone."""

    def __init__(self, enc_sizes=[3, 16, 32, 64, 128, 256, 512], skips=True,
                 grid_count=[10, 5, 1], init_weights=[0.5, 0.5], init_value=0.8):
        super(PHNet, self).__init__()
        self.skips = skips
        self.feature_extractor = PatchHarmonizerBlock(
            in_channels=enc_sizes[0], grid_count=grid_count[1])
        self.encoder = nn.ModuleList([
            conv_bn(enc_sizes[0], enc_sizes[1], kernel_size=4, stride=2),
            conv_bn(enc_sizes[1], enc_sizes[2], kernel_size=3, stride=1),
            conv_bn(enc_sizes[2], enc_sizes[3], kernel_size=4, stride=2),
            conv_bn(enc_sizes[3], enc_sizes[4], kernel_size=3, stride=1),
            conv_bn(enc_sizes[4], enc_sizes[5], kernel_size=4, stride=2),
            conv_bn(enc_sizes[5], enc_sizes[6], kernel_size=3, stride=1),
        ])
        dec_ins = enc_sizes[::-1]
        dec_sizes = enc_sizes[::-1]
        self.start_level = len(dec_sizes) - len(grid_count)
        self.normalizers = nn.ModuleList([
            PatchNormalizer(in_channels=dec_sizes[self.start_level + i],
                            grid_count=count, weights=init_weights,
                            eps=1e-7, init_value=init_value)
            for i, count in enumerate(grid_count)
        ])
        self.decoder = nn.ModuleList([
            ConvTransposeUp(
                dec_ins[0], dec_sizes[1], norm=nn.BatchNorm2d,
                kernel_size=3, stride=1, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[1], dec_sizes[2], norm=nn.BatchNorm2d,
                kernel_size=4, stride=2, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[2],
                dec_sizes[3], norm=nn.BatchNorm2d,
                kernel_size=3, stride=1, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[3], dec_sizes[4], norm=None,
                kernel_size=4, stride=2, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[4], dec_sizes[5], norm=None,
                kernel_size=3, stride=1, activation=nn.LeakyReLU),
            ConvTransposeUp(
                dec_ins[5], 3, norm=None,
                kernel_size=4, stride=2, activation=None),
        ])
        self.skip = nn.ModuleList([SkipConnect(x) for x in dec_ins])
        self.SE_block = SEBlock(enc_sizes[6])

    def forward(self, img, mask):
        x = img
        enc_outs = [x]
        # Patch-harmonized composite, used as an auxiliary gate at the bottleneck.
        x_harm = self.feature_extractor(x * mask, x * (1 - mask), mask)
        masks = [mask]
        for i, down_layer in enumerate(self.encoder):
            x = down_layer(x)
            # Every other encoder block halves the resolution; track the mask at each scale.
            scale_factor = 1. / (pow(2, 1 - i % 2))
            masks.append(F.interpolate(masks[-1], scale_factor=scale_factor))
            enc_outs.append(x)
        x = self.SE_block(x, aux_inp=x_harm)
        masks = masks[::-1]
        for i, (up_layer, enc_out) in enumerate(zip(self.decoder, enc_outs[::-1])):
            if i >= self.start_level:
                # Patch-normalize the skip features at the finer decoder levels.
                enc_out = self.normalizers[i - self.start_level](enc_out, enc_out, masks[i])
            x = torch.cat([x, enc_out], 1)
            x = self.skip[i](x)
            x = up_layer(x)
        harmonized = torch.sigmoid(x)
        return harmonized

    def set_requires_grad(self, modules=["encoder", "sh_head", "resquare", "decoder"],
                          value=False):
        for module in modules:
            attr = getattr(self, module, None)
            if attr is not None:
                attr.requires_grad_(value)
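

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds PHNet with its
# default configuration and runs one forward pass on random tensors. The
# (256, 512) resolution follows the input note above; the composite and mask
# values are assumed to lie in [0, 1]. This relies on the sibling .normalizer
# module being importable, so run it inside the package, e.g.
# `python -m <package>.<this_module>`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = PHNet()  # default enc_sizes / grid_count
    model.eval()

    comp = torch.rand(1, 3, 256, 512)                   # composite image in [0, 1]
    mask = (torch.rand(1, 1, 256, 512) > 0.5).float()   # binary foreground mask

    with torch.no_grad():
        out = model(comp, mask)  # harmonized image at the input resolution
    print(out.shape)             # expected: torch.Size([1, 3, 256, 512])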