"""Patch-based colour harmonization modules.

Both modules re-normalize a foreground image so that its per-channel mean
and standard deviation follow those of a background image. They mix a
global (whole-image) estimate with estimates taken on a
``grid_count x grid_count`` grid of patches, and then alpha-blend the
result over the background.
"""
import numpy as np
import cv2
import os
import tqdm
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from .util import rgb_to_lab, lab_to_rgb


def blend(f, b, a):
    """Alpha-composite foreground ``f`` over background ``b``.

    Computes ``f * a + b * (1 - a)`` element-wise (with broadcasting), so
    ``a == 1`` keeps the foreground and ``a == 0`` keeps the background.
    """
    return f * a + b * (1 - a)


def _grid_cell_index(shape, grid_count, h, w):
    """Return the index tuple selecting grid cell ``(h, w)`` of a 4-D tensor.

    Cells are ``shape[2] // grid_count`` by ``shape[3] // grid_count``. The
    last row and column of cells extend to the tensor's edge, so no pixels
    are dropped when the size is not divisible by ``grid_count``.
    """
    dx = shape[2] // grid_count
    dy = shape[3] // grid_count
    last = grid_count - 1
    rows = slice(h * dx, None if h == last else (h + 1) * dx)
    cols = slice(w * dy, None if w == last else (w + 1) * dy)
    return (slice(None), slice(None), rows, cols)


def _masked_mean_std(img, mask, dim, eps):
    """Mask-weighted mean and standard deviation of ``img`` over ``dim``.

    Both results have two trailing singleton dims appended to the reduced
    shape. For ``dim=(2, 3)`` on a ``(B, C, H, W)`` input that is
    ``(B, C, 1, 1)``.
    """
    total = torch.sum(img * mask, dim=dim)
    count = torch.sum(mask, dim=dim)
    mean = (total / (count + eps))[:, :, None, None]
    # NOTE(review): the squared deviation is weighted by mask**2, not by
    # mask as the mean is. This is identical for binary masks but biased
    # low for soft (fractional) masks. Confirm whether that is intended.
    var = torch.sum(((img - mean) * mask) ** 2, dim=dim) / (count + eps)
    return mean, torch.sqrt(var[:, :, None, None] + eps)


class PatchedHarmonizer(nn.Module):
    """Match foreground colour statistics to the background, per patch.

    The foreground is split into a ``grid_count x grid_count`` grid. Each
    cell is re-normalized twice:

    * with the statistics of the matching patch (local);
    * with whole-image statistics (global).

    The two results are mixed with the learnable ``grid_weights``.

    Args:
        grid_count: Number of patches along each spatial axis.
        init_weights: Initial ``(local, global)`` mixing weights.
    """

    def __init__(self, grid_count=1, init_weights=(0.9, 0.1)):
        super().__init__()
        self.eps = 1e-8
        # Index 0 weights the local (per-patch) result and index 1 weights
        # the global one (see ``normalize_channel``).
        self.grid_weights = torch.nn.Parameter(
            torch.FloatTensor(init_weights), requires_grad=True)
        self.grid_count = grid_count

    def lab_shift(self, x, invert=False):
        """Offset/scale the first three channels of ``x``.

        The forward direction (``invert=False``) multiplies channel 0 by
        2.55 and adds 128 to channels 1 and 2. ``invert=True`` undoes this.
        Presumably this maps LAB (L in [0, 100], ab signed) into a 0-255
        range. TODO: confirm against ``util.rgb_to_lab``.

        Returns a new float tensor. The input is left untouched.
        """
        # ``Tensor.float()`` returns ``x`` itself when it is already float32.
        # Clone first so the in-place edits below never mutate the caller's
        # tensor (or a leaf that requires grad).
        x = x.float().clone()
        if invert:
            x[:, 0, :, :] /= 2.55
            x[:, 1, :, :] -= 128
            x[:, 2, :, :] -= 128
        else:
            x[:, 0, :, :] *= 2.55
            x[:, 1, :, :] += 128
            x[:, 2, :, :] += 128
        return x

    def get_mean_std(self, img, mask, dim=(2, 3)):
        """Return the ``mask``-weighted ``(mean, std)`` of ``img`` over ``dim``.

        Each result gets two trailing singleton dims, e.g. ``(B, C, 1, 1)``.
        """
        return _masked_mean_std(img, mask, dim, self.eps)

    def compute_patch_statistics(self, lab):
        """Compute unmasked statistics for every grid cell of ``lab``.

        Returns:
            ``(means, stds)``: nested lists indexed ``[h][w]``. Each entry is
            a tensor produced by ``compute_mean_var`` (``(B, C, 1, 1)`` for a
            4-D input).
        """
        means, stds = [], []
        for h in range(self.grid_count):
            row_means, row_stds = [], []
            for w in range(self.grid_count):
                cell = _grid_cell_index(lab.shape, self.grid_count, h, w)
                m, s = self.compute_mean_var(lab[cell], dim=(2, 3))
                row_means.append(m)
                row_stds.append(s)
            means.append(row_means)
            stds.append(row_stds)
        return means, stds

    def compute_mean_var(self, x, dim=(1, 2)):
        """Return the unmasked mean and standard deviation of ``x`` over ``dim``.

        Despite the name, the second value is the standard deviation: the
        ``sqrt`` of ``Tensor.var``, which uses the unbiased estimator. Both
        results get two trailing singleton dims appended.
        """
        x = x.float()
        mean = x.mean(dim=dim)[:, :, None, None]
        std = torch.sqrt(x.var(dim=dim))[:, :, None, None]
        return mean, std

    def forward(self, fg_rgb, bg_rgb, alpha, masked_stats=False):
        """Harmonize ``fg_rgb`` to ``bg_rgb`` and composite with ``alpha``.

        Args:
            fg_rgb: Foreground image, ``B x C x H x W``.
            bg_rgb: Background image. It is resized with ``F.interpolate``'s
                default (nearest) mode to the foreground's spatial size.
            alpha: Compositing mask, broadcastable to the images.
            masked_stats: If True, background global statistics are taken
                over ``1 - alpha`` (and foreground statistics over an all-ones
                mask). Otherwise plain unmasked statistics are used.
                Per-patch statistics are always unmasked.

        Returns:
            ``(composite, fg_harm)``.

        Side effects:
            Caches the global and per-patch statistics on ``self`` for use
            by ``harmonize`` / ``normalize_channel``.
        """
        bg_rgb = F.interpolate(bg_rgb, size=fg_rgb.shape[2:])
        # LAB conversion is currently disabled, so statistics are matched
        # directly in the input colour space. The previous pipeline was
        # ``self.lab_shift(rgb_to_lab(x / 255.))`` on the way in and
        # ``lab_to_rgb(fg_harm)`` on the way out.
        bg_lab = bg_rgb
        fg_lab = fg_rgb

        if masked_stats:
            self.bg_global_mean, self.bg_global_var = self.get_mean_std(
                img=bg_lab, mask=(1 - alpha))
            # NOTE(review): the foreground uses an all-ones mask, not
            # ``alpha``. Confirm this is intended.
            self.fg_global_mean, self.fg_global_var = self.get_mean_std(
                img=fg_lab, mask=torch.ones_like(alpha))
        else:
            self.bg_global_mean, self.bg_global_var = self.compute_mean_var(
                bg_lab, dim=(2, 3))
            self.fg_global_mean, self.fg_global_var = self.compute_mean_var(
                fg_lab, dim=(2, 3))

        self.bg_means, self.bg_vars = self.compute_patch_statistics(bg_lab)
        self.fg_means, self.fg_vars = self.compute_patch_statistics(fg_lab)

        fg_harm = self.harmonize(fg_lab)

        # ``bg_rgb`` was already resized above, so resizing it again would
        # be a no-op.
        # NOTE(review): ``fg_harm`` stays in the input scale, while the
        # background is divided by 255 here. With LAB conversion disabled,
        # these look mismatched. Confirm the expected input range.
        bg = bg_rgb / 255.
        composite = blend(fg_harm, bg, alpha)
        return composite, fg_harm

    def harmonize(self, fg):
        """Apply ``normalize_channel`` to every grid cell of ``fg``.

        Requires the statistics cached by ``forward``. The output is
        allocated with ``torch.zeros_like(fg)``, so it shares ``fg``'s dtype.
        """
        harmonized = torch.zeros_like(fg)
        for h in range(self.grid_count):
            for w in range(self.grid_count):
                cell = _grid_cell_index(fg.shape, self.grid_count, h, w)
                harmonized[cell] = self.normalize_channel(fg[cell], h, w)
        return harmonized

    def normalize_channel(self, value, h, w):
        """Re-normalize patch ``value`` (grid cell ``(h, w)``) to the background.

        For both the cell statistics and the global statistics, computes
        ``(value - fg_mean) * bg_std / (fg_std + eps) + bg_mean``. Returns
        ``grid_weights[0] * local + grid_weights[1] * global``.
        """
        # Global-to-global normalization.
        normalized_global = (
            (value - self.fg_global_mean)
            * (self.bg_global_var / (self.fg_global_var + self.eps))
            + self.bg_global_mean)

        # Local (cell) to local normalization.
        fg_local_mean, fg_local_std = self.fg_means[h][w], self.fg_vars[h][w]
        bg_local_mean, bg_local_std = self.bg_means[h][w], self.bg_vars[h][w]
        normalized_local = (
            (value - fg_local_mean)
            * (bg_local_std / (fg_local_std + self.eps))
            + bg_local_mean)

        return (self.grid_weights[0] * normalized_local
                + self.grid_weights[1] * normalized_global)

    def normalize_fg(self, value):
        """Legacy variant of ``normalize_channel``. Currently unusable.

        NOTE(review): this reads ``self.fg_local_mean`` and
        ``self.bg_local_mean``, which nothing in this class sets, so any call
        raises ``AttributeError``. ``grid_weights`` is also 1-D here, so the
        six-slot index below would not fit it either. Kept only for interface
        compatibility.
        """
        zeroed_mean = value - \
            (self.fg_local_mean * self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        scaled_var = zeroed_mean * \
            (self.bg_global_var / (self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean * self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        return normalized_lg


class PatchNormalizer(nn.Module):
    """Learnable, mask-aware statistics transfer from background to foreground.

    The foreground (statistics over ``mask``) is standardized and then
    re-coloured with the background statistics (over ``1 - mask``) in two
    ways:

    * with the global background mean;
    * with a learnable weighted combination of per-patch background means.

    Each result is followed by a learnable per-channel affine. The two are
    mixed by ``weights`` and composited over an affinely re-normalized
    background.

    Args:
        in_channels: Number of image channels ``C``.
        eps: Numerical stabilizer for the masked statistics.
        grid_count: Number of patches ``G`` along each spatial axis.
        weights: Initial ``(global, patch)`` mixing weights.
        init_value: Initial value of the per-channel scale parameters.
            The biases are ``init_value * 0``, i.e. zero.
    """

    def __init__(self, in_channels=3, eps=1e-7, grid_count=1,
                 weights=(0.5, 0.5), init_value=1e-2):
        super().__init__()
        self.grid_count = grid_count
        self.eps = eps

        self.weights = nn.Parameter(
            torch.FloatTensor(weights), requires_grad=True)
        # Per-channel affine parameters, each shaped (1, C, 1, 1).
        self.fg_var = nn.Parameter(
            init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
        self.fg_bias = nn.Parameter(
            init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
        self.patched_fg_var = nn.Parameter(
            init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
        self.patched_fg_bias = nn.Parameter(
            init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
        self.bg_var = nn.Parameter(
            init_value * torch.ones(in_channels)[None, :, None, None], requires_grad=True)
        self.bg_bias = nn.Parameter(
            init_value * torch.zeros(in_channels)[None, :, None, None], requires_grad=True)
        # Shape (1, C, G, G). ``forward`` sums over the two G axes only, so
        # this init yields (mean of patch means) / C per channel.
        # NOTE(review): the extra 1/C looks unintended. Confirm before
        # changing, since it affects trained checkpoints.
        self.grid_weights = torch.nn.Parameter(
            torch.ones((in_channels, grid_count, grid_count))[None, :, :, :]
            / (grid_count * grid_count * in_channels),
            requires_grad=True)

    def local_normalization(self, value):
        """Legacy local normalization. Currently unusable.

        NOTE(review): this reads ``fg_local_mean``, ``bg_local_mean``,
        ``fg_global_var`` and ``bg_global_var``, none of which this class
        sets, so any call raises ``AttributeError``. Kept only for interface
        compatibility.
        """
        zeroed_mean = value - \
            (self.fg_local_mean * self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        scaled_var = zeroed_mean * \
            (self.bg_global_var / (self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean * self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        return normalized_lg

    def get_mean_std(self, img, mask, dim=(2, 3)):
        """Return the ``mask``-weighted ``(mean, std)`` of ``img`` over ``dim``.

        Each result gets two trailing singleton dims, e.g. ``(B, C, 1, 1)``.
        """
        return _masked_mean_std(img, mask, dim, self.eps)

    def compute_patch_statistics(self, img, mask):
        """Compute masked statistics for every grid cell of ``img``.

        ``mask`` is cut with the same cell boundaries as ``img``.

        Returns:
            ``(means, stds)``, each stacked to shape ``(G, G, B, C)``.
        """
        means, stds = [], []
        for h in range(self.grid_count):
            row_means, row_stds = [], []
            for w in range(self.grid_count):
                cell = _grid_cell_index(img.shape, self.grid_count, h, w)
                m, s = self.get_mean_std(img[cell], mask[cell], dim=(2, 3))
                row_means.append(m.reshape(m.shape[:2]))
                row_stds.append(s.reshape(s.shape[:2]))
            means.append(torch.stack(row_means))
            stds.append(torch.stack(row_stds))
        return torch.stack(means), torch.stack(stds)

    def compute_mean_var(self, x, dim=(2, 3)):
        """Return the unmasked mean and standard deviation of ``x`` over ``dim``.

        Unlike ``get_mean_std``, no singleton dims are appended.
        """
        x = x.float()
        return x.mean(dim=dim), torch.sqrt(x.var(dim=dim))

    def forward(self, fg, bg, mask):
        """Harmonize ``fg`` to ``bg`` and composite the result with ``mask``.

        Side effect: stores the per-patch background statistics in
        ``self.local_means`` / ``self.local_vars`` (shape ``(G, G, B, C)``).
        """
        inv_mask = 1 - mask
        self.local_means, self.local_vars = self.compute_patch_statistics(
            bg, inv_mask)

        # Background: standardize over the un-masked region, then apply the
        # learnable affine.
        bg_mean, bg_std = self.get_mean_std(bg, inv_mask)
        bg_normalized = (bg - bg_mean) / bg_std * self.bg_var + self.bg_bias

        # Foreground: standardize over the masked region.
        fg_mean, fg_std = self.get_mean_std(fg, mask)
        unscaled = (fg - fg_mean) / fg_std

        # (G, G, B, C) -> (B, C, G, G), weight each cell, sum over the grid.
        mean_patched_back = (
            self.local_means.permute(2, 3, 0, 1) * self.grid_weights
        ).sum(dim=[2, 3])[:, :, None, None]

        normalized = unscaled * bg_std + bg_mean
        patch_normalized = unscaled * bg_std + mean_patched_back

        fg_normalized = normalized * self.fg_var + self.fg_bias
        fg_patch_normalized = (
            patch_normalized * self.patched_fg_var + self.patched_fg_bias)
        fg_result = (self.weights[0] * fg_normalized
                     + self.weights[1] * fg_patch_normalized)

        return blend(fg_result, bg_normalized, mask)