Spaces:
Running
Running
import numpy as np | |
import cv2 | |
import os | |
import tqdm | |
import time | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from .util import rgb_to_lab, lab_to_rgb | |
def blend(f, b, a):
    """Alpha-composite ``f`` over ``b`` using ``a`` as the weight of ``f``.

    Computes ``f*a + b*(1 - a)`` with whatever arithmetic the operands
    support (Python scalars, tensors with broadcasting, ...).
    """
    foreground_part = f * a
    background_part = b * (1 - a)
    return foreground_part + background_part
class PatchedHarmonizer(nn.Module):
    """Match foreground colour statistics to a background, globally and per patch.

    The image plane is divided into a ``grid_count`` x ``grid_count`` grid.
    For every cell the foreground is re-centred and re-scaled twice: once with
    whole-image ("global") mean/std and once with the cell's own ("local")
    mean/std.  The two results are mixed with the learnable ``grid_weights``
    (index 0 weights the local result, index 1 the global one).

    Tensors are indexed as 4-D ``[:, :, rows, cols]`` throughout; the original
    comments label this layout ``b x C x H x W``.
    """

    def __init__(self, grid_count=1, init_weights=(0.9, 0.1)):
        """Create the module.

        Args:
            grid_count: number of cells along each spatial axis.
            init_weights: initial ``(local, global)`` mixing weights.
        """
        super().__init__()
        self.eps = 1e-8
        self.grid_weights = torch.nn.Parameter(
            torch.FloatTensor(init_weights), requires_grad=True)
        self.grid_count = grid_count

    def _patch_slices(self, height, width):
        """Yield ``(h, w, rows, cols)`` for every grid cell.

        The last row/column slice is open-ended so that pixels left over by
        the integer division still belong to a cell.
        """
        dx = height // self.grid_count
        dy = width // self.grid_count
        last = self.grid_count - 1
        for h in range(self.grid_count):
            rows = slice(h * dx, None if h == last else (h + 1) * dx)
            for w in range(self.grid_count):
                cols = slice(w * dy, None if w == last else (w + 1) * dy)
                yield h, w, rows, cols

    def lab_shift(self, x, invert=False):
        """Rescale channel 0 by 2.55 and offset channels 1 and 2 by 128.

        With ``invert=True`` the inverse mapping is applied.  The input is
        never modified: the shift is applied to a float32 copy, which is
        returned.  (Previously a float32 input was altered in place because
        ``Tensor.float()`` returns the same tensor when no cast is needed.)
        """
        x = x.to(torch.float32, copy=True)
        if invert:
            x[:, 0, :, :] /= 2.55
            x[:, 1, :, :] -= 128
            x[:, 2, :, :] -= 128
        else:
            x[:, 0, :, :] *= 2.55
            x[:, 1, :, :] += 128
            x[:, 2, :, :] += 128
        return x

    def get_mean_std(self, img, mask, dim=(2, 3)):
        """Return mask-weighted mean and std of ``img`` over ``dim``.

        Both results are re-expanded with two trailing singleton axes so they
        broadcast against ``img``.  ``eps`` guards the division by an empty
        mask and the square root of a zero variance.

        NOTE(review): the mask is multiplied in *before* squaring, so for a
        non-binary mask the deviations are weighted by ``mask**2`` while the
        normaliser is ``sum(mask)`` - confirm this is intended.
        """
        weighted_sum = torch.sum(img * mask, dim=dim)
        num = torch.sum(mask, dim=dim)
        mean = (weighted_sum / (num + self.eps))[:, :, None, None]
        var = torch.sum(((img - mean) * mask) ** 2, dim=dim) / (num + self.eps)
        var = var[:, :, None, None]
        return mean, torch.sqrt(var + self.eps)

    def compute_patch_statistics(self, lab):
        """Return per-cell means and stds of ``lab``.

        Returns:
            ``(means, stds)`` - two nested lists indexed ``[h][w]``, each entry
            being what ``compute_mean_var`` returns for that cell.
        """
        means = [[] for _ in range(self.grid_count)]
        stds = [[] for _ in range(self.grid_count)]
        for h, _, rows, cols in self._patch_slices(lab.shape[2], lab.shape[3]):
            m, s = self.compute_mean_var(lab[:, :, rows, cols], dim=(2, 3))
            means[h].append(m)
            stds[h].append(s)
        return means, stds

    def compute_mean_var(self, x, dim=(1, 2)):
        """Return mean and standard deviation of ``x`` over ``dim``.

        Despite the name, the second value is ``sqrt(var)``.  Both results get
        two trailing singleton axes appended.

        NOTE(review): every call in this class passes ``dim=(2, 3)``; the
        ``(1, 2)`` default is kept only for backward compatibility.
        """
        x = x.float()
        mean = x.mean(dim=dim)[:, :, None, None]
        std = torch.sqrt(x.var(dim=dim))[:, :, None, None]
        return mean, std

    def forward(self, fg_rgb, bg_rgb, alpha, masked_stats=False):
        """Harmonize ``fg_rgb`` towards ``bg_rgb`` and composite them.

        Args:
            fg_rgb: foreground image batch.
            bg_rgb: background image batch; resized to the foreground's
                spatial size.
            alpha: blending weight of the foreground.
            masked_stats: if true, global background statistics are taken
                where ``1 - alpha`` is non-zero via ``get_mean_std``;
                otherwise plain mean/std over the whole image are used.

        Returns:
            ``(composite, fg_harm)``.

        Side effects: stores the global and per-cell statistics on ``self``
        (``bg_global_mean``, ``fg_means``, ...), which ``harmonize`` reads.
        """
        bg_rgb = F.interpolate(bg_rgb, size=fg_rgb.shape[2:])

        # Lab conversion is currently disabled: statistics are computed
        # directly on the tensors as given.
        bg_lab = bg_rgb
        fg_lab = fg_rgb

        if masked_stats:
            self.bg_global_mean, self.bg_global_var = self.get_mean_std(
                img=bg_lab, mask=(1 - alpha))
            self.fg_global_mean, self.fg_global_var = self.get_mean_std(
                img=fg_lab, mask=torch.ones_like(alpha))
        else:
            self.bg_global_mean, self.bg_global_var = self.compute_mean_var(
                bg_lab, dim=(2, 3))
            self.fg_global_mean, self.fg_global_var = self.compute_mean_var(
                fg_lab, dim=(2, 3))

        self.bg_means, self.bg_vars = self.compute_patch_statistics(bg_lab)
        self.fg_means, self.fg_vars = self.compute_patch_statistics(fg_lab)

        fg_harm = self.harmonize(fg_lab)

        # bg_rgb already has the foreground's spatial size, so the second
        # (default nearest-neighbour, same-size) interpolation was a no-op.
        # NOTE(review): only the background is divided by 255 here while
        # fg_harm keeps the input scale - confirm the expected value ranges.
        bg = bg_rgb / 255.
        composite = blend(fg_harm, bg, alpha)
        return composite, fg_harm

    def harmonize(self, fg):
        """Apply ``normalize_channel`` to every grid cell of ``fg``.

        Requires the statistics stored by ``forward`` to be present on
        ``self``.  Returns a new tensor shaped like ``fg``.
        """
        harmonized = torch.zeros_like(fg)
        for h, w, rows, cols in self._patch_slices(fg.shape[2], fg.shape[3]):
            harmonized[:, :, rows, cols] = self.normalize_channel(
                fg[:, :, rows, cols], h, w)
        return harmonized

    def normalize_channel(self, value, h, w):
        """Mix local and global statistic transfer for grid cell ``(h, w)``.

        Each transfer subtracts the foreground mean, scales by the ratio of
        background to foreground std, and adds the background mean.
        """
        fg_local_mean, fg_local_var = self.fg_means[h][w], self.fg_vars[h][w]
        bg_local_mean, bg_local_var = self.bg_means[h][w], self.bg_vars[h][w]

        # global -> global transfer
        global_ratio = self.bg_global_var / (self.fg_global_var + self.eps)
        normalized_global = (
            (value - self.fg_global_mean) * global_ratio + self.bg_global_mean)

        # local -> local transfer
        local_ratio = bg_local_var / (fg_local_var + self.eps)
        normalized_local = (value - fg_local_mean) * local_ratio + bg_local_mean

        return (self.grid_weights[0] * normalized_local
                + self.grid_weights[1] * normalized_global)

    def normalize_fg(self, value):
        """Legacy statistic transfer using grid-weighted local means.

        NOTE(review): this reads ``self.fg_local_mean`` and
        ``self.bg_local_mean``, which no method of this class assigns, and
        indexes ``grid_weights`` with two slices although ``__init__`` creates
        it from a flat weight sequence.  It appears to be unused legacy code;
        the logic is left unchanged - confirm before calling.
        """
        zeroed_mean = value - \
            (self.fg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        scaled_var = zeroed_mean * \
            (self.bg_global_var / (self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        return normalized_lg
class PatchNormalizer(nn.Module):
    """Learnable statistic transfer from a background onto a masked foreground.

    The foreground is whitened with its own masked mean/std, re-coloured with
    the background's masked statistics (once with the global mean, once with
    a learnable weighted sum of per-cell means), passed through per-channel
    affine parameters, mixed with ``weights`` and finally blended with the
    normalized background.
    """

    def __init__(self, in_channels=3, eps=1e-7, grid_count=1,
                 weights=(0.5, 0.5), init_value=1e-2):
        """Create the module.

        Args:
            in_channels: number of image channels.
            eps: stabiliser used in divisions and square roots.
            grid_count: number of cells along each spatial axis.
            weights: initial mixing weights for the ``(global, patched)``
                foreground results.
            init_value: initial value of every per-channel scale parameter.
        """
        super().__init__()
        self.grid_count = grid_count
        self.eps = eps

        def scale():
            # per-channel multiplier, broadcastable as (1, C, 1, 1)
            return nn.Parameter(
                init_value * torch.ones(in_channels)[None, :, None, None],
                requires_grad=True)

        def bias():
            # per-channel offset, starts at zero
            return nn.Parameter(
                init_value * torch.zeros(in_channels)[None, :, None, None],
                requires_grad=True)

        # Registration order is kept as in the original so parameter
        # iteration order does not change.
        self.weights = nn.Parameter(
            torch.FloatTensor(weights), requires_grad=True)
        self.fg_var = scale()
        self.fg_bias = bias()
        self.patched_fg_var = scale()
        self.patched_fg_bias = bias()
        self.bg_var = scale()
        self.bg_bias = bias()
        # Uniform weights over channels and grid cells, summing to one.
        cell_count = grid_count * grid_count * in_channels
        self.grid_weights = torch.nn.Parameter(
            torch.ones((in_channels, grid_count, grid_count))[None, :, :, :]
            / cell_count,
            requires_grad=True)

    def _patch_slices(self, height, width):
        """Yield ``(h, w, rows, cols)`` for every grid cell.

        The last row/column slice is open-ended so that pixels left over by
        the integer division still belong to a cell.
        """
        dx = height // self.grid_count
        dy = width // self.grid_count
        last = self.grid_count - 1
        for h in range(self.grid_count):
            rows = slice(h * dx, None if h == last else (h + 1) * dx)
            for w in range(self.grid_count):
                cols = slice(w * dy, None if w == last else (w + 1) * dy)
                yield h, w, rows, cols

    def local_normalization(self, value):
        """Legacy statistic transfer using grid-weighted local means.

        NOTE(review): this reads ``self.fg_local_mean``,
        ``self.bg_local_mean``, ``self.fg_global_var`` and
        ``self.bg_global_var``, none of which is assigned by any method of
        this class.  It appears to be unused legacy code; the logic is left
        unchanged - confirm before calling.
        """
        zeroed_mean = value - \
            (self.fg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        scaled_var = zeroed_mean * \
            (self.bg_global_var / (self.fg_global_var + self.eps))
        normalized_lg = scaled_var + \
            (self.bg_local_mean *
             self.grid_weights[None, None, :, :, None, None]).sum().squeeze()
        return normalized_lg

    def get_mean_std(self, img, mask, dim=(2, 3)):
        """Return mask-weighted mean and std of ``img`` over ``dim``.

        Both results are re-expanded with two trailing singleton axes so they
        broadcast against ``img``.  ``eps`` guards the division by an empty
        mask and the square root of a zero variance.

        NOTE(review): the mask is multiplied in *before* squaring, so for a
        non-binary mask the deviations are weighted by ``mask**2`` while the
        normaliser is ``sum(mask)`` - confirm this is intended.
        """
        weighted_sum = torch.sum(img * mask, dim=dim)
        num = torch.sum(mask, dim=dim)
        mean = (weighted_sum / (num + self.eps))[:, :, None, None]
        var = torch.sum(((img - mean) * mask) ** 2, dim=dim) / (num + self.eps)
        var = var[:, :, None, None]
        return mean, torch.sqrt(var + self.eps)

    def compute_patch_statistics(self, img, mask):
        """Return masked per-cell means and stds as stacked tensors.

        Each cell's statistics are reshaped to their first two axes and the
        cells are stacked row by row, so both returned tensors are indexed
        ``[h, w, ...]``.
        """
        means = [[] for _ in range(self.grid_count)]
        stds = [[] for _ in range(self.grid_count)]
        for h, _, rows, cols in self._patch_slices(img.shape[2], img.shape[3]):
            m, s = self.get_mean_std(
                img[:, :, rows, cols], mask[:, :, rows, cols], dim=(2, 3))
            means[h].append(m.reshape(m.shape[:2]))
            stds[h].append(s.reshape(s.shape[:2]))
        stacked_means = torch.stack([torch.stack(row) for row in means])
        stacked_stds = torch.stack([torch.stack(row) for row in stds])
        return stacked_means, stacked_stds

    def compute_mean_var(self, x, dim=(2, 3)):
        """Return mean and standard deviation of ``x`` over ``dim``.

        Despite the name, the second value is ``sqrt(var)``.
        """
        x = x.float()
        return x.mean(dim=dim), torch.sqrt(x.var(dim=dim))

    def forward(self, fg, bg, mask):
        """Normalize ``fg`` towards ``bg`` and composite them with ``mask``.

        Args:
            fg: foreground image batch.
            bg: background image batch.
            mask: weight of the foreground; ``1 - mask`` selects the
                background region used for its statistics.

        Returns:
            The blended composite.

        Side effects: stores the per-cell background statistics on ``self``
        as ``local_means`` / ``local_vars``.
        """
        self.local_means, self.local_vars = self.compute_patch_statistics(
            bg, (1 - mask))

        # Background: whiten with its own masked statistics, then apply the
        # learnable per-channel affine.
        bg_mean, bg_var = self.get_mean_std(bg, 1 - mask)
        bg_normalized = ((bg - bg_mean) / bg_var) * self.bg_var + self.bg_bias

        # Foreground: whiten with its own masked statistics.
        fg_mean, fg_var = self.get_mean_std(fg, mask)
        unscaled = (fg - fg_mean) / fg_var

        # Move the cell axes last so they line up with grid_weights, then
        # collapse them into one weighted mean per channel.
        mean_patched_back = (
            self.local_means.permute(2, 3, 0, 1) * self.grid_weights
        ).sum(dim=[2, 3])[:, :, None, None]

        normalized = unscaled * bg_var + bg_mean
        patch_normalized = unscaled * bg_var + mean_patched_back

        fg_normalized = normalized * self.fg_var + self.fg_bias
        fg_patch_normalized = patch_normalized * \
            self.patched_fg_var + self.patched_fg_bias

        fg_result = self.weights[0] * fg_normalized + \
            self.weights[1] * fg_patch_normalized
        composite = blend(fg_result, bg_normalized, mask)
        return composite