Spaces:
Running
on
Zero
Running
on
Zero
EPS = 1e-7 | |
import kornia | |
from typing import Dict, Iterator, List, Optional, Tuple, Union | |
import torchvision | |
from guided_diffusion import dist_util, logger | |
from pdb import set_trace as st | |
from torch.nn import functional as F | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import lpips | |
from . import * | |
from .sdfstudio_losses import ScaleAndShiftInvariantLoss | |
from ldm.util import default, instantiate_from_config | |
from .vqperceptual import hinge_d_loss, vanilla_d_loss | |
from torch.autograd import Variable | |
from math import exp | |
def gaussian(window_size, sigma): | |
gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) | |
return gauss / gauss.sum() | |
def create_window(window_size, channel): | |
_1D_window = gaussian(window_size, 1.5).unsqueeze(1) | |
_2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) | |
window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) | |
return window | |
def _ssim(img1, img2, window, window_size, channel, size_average=True): | |
mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) | |
mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) | |
mu1_sq = mu1.pow(2) | |
mu2_sq = mu2.pow(2) | |
mu1_mu2 = mu1 * mu2 | |
sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq | |
sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq | |
sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 | |
C1 = 0.01 ** 2 | |
C2 = 0.03 ** 2 | |
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) | |
if size_average: | |
return ssim_map.mean() | |
else: | |
return ssim_map.mean(1).mean(1).mean(1) | |
def weights_init(m): | |
classname = m.__class__.__name__ | |
if classname.find("Conv") != -1: | |
nn.init.normal_(m.weight.data, 0.0, 0.02) | |
elif classname.find("BatchNorm") != -1: | |
nn.init.normal_(m.weight.data, 1.0, 0.02) | |
nn.init.constant_(m.bias.data, 0) | |
# Main loss function used for ZoeDepth. Copy/paste from AdaBins repo (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7) | |
def extract_key(prediction, key): | |
if isinstance(prediction, dict): | |
return prediction[key] | |
return prediction | |
class SILogLoss(nn.Module): | |
"""SILog loss (pixel-wise)""" | |
def __init__(self, beta=0.15): | |
super(SILogLoss, self).__init__() | |
self.name = 'SILog' | |
self.beta = beta | |
def forward(self, | |
input, | |
target, | |
mask=None, | |
interpolate=True, | |
return_interpolated=False): | |
# input = extract_key(input, KEY_OUTPUT) | |
if input.shape[-1] != target.shape[-1] and interpolate: | |
input = nn.functional.interpolate(input, | |
target.shape[-2:], | |
mode='bilinear', | |
align_corners=True) | |
intr_input = input | |
else: | |
intr_input = input | |
if target.ndim == 3: | |
target = target.unsqueeze(1) | |
if mask is not None: | |
if mask.ndim == 3: | |
mask = mask.unsqueeze(1) | |
input = input[mask] | |
target = target[mask] | |
# with torch.amp.autocast(enabled=False): # amp causes NaNs in this loss function | |
alpha = 1e-7 | |
g = torch.log(input + alpha) - torch.log(target + alpha) | |
# n, c, h, w = g.shape | |
# norm = 1/(h*w) | |
# Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2 | |
Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2) | |
loss = 10 * torch.sqrt(Dg) | |
if torch.isnan(loss): | |
print("Nan SILog loss") | |
print("input:", input.shape) | |
print("target:", target.shape) | |
print("G", torch.sum(torch.isnan(g))) | |
print("Input min max", torch.min(input), torch.max(input)) | |
print("Target min max", torch.min(target), torch.max(target)) | |
print("Dg", torch.isnan(Dg)) | |
print("loss", torch.isnan(loss)) | |
if not return_interpolated: | |
return loss | |
return loss, intr_input | |
def get_outnorm(x: torch.Tensor, out_norm: str = '') -> torch.Tensor: | |
""" Common function to get a loss normalization value. Can | |
normalize by either the batch size ('b'), the number of | |
channels ('c'), the image size ('i') or combinations | |
('bi', 'bci', etc) | |
""" | |
# b, c, h, w = x.size() | |
img_shape = x.shape | |
if not out_norm: | |
return 1 | |
norm = 1 | |
if 'b' in out_norm: | |
# normalize by batch size | |
# norm /= b | |
norm /= img_shape[0] | |
if 'c' in out_norm: | |
# normalize by the number of channels | |
# norm /= c | |
norm /= img_shape[-3] | |
if 'i' in out_norm: | |
# normalize by image/map size | |
# norm /= h*w | |
norm /= img_shape[-1] * img_shape[-2] | |
return norm | |
class CharbonnierLoss(torch.nn.Module): | |
"""Charbonnier Loss (L1)""" | |
def __init__(self, eps=1e-6, out_norm: str = 'bci'): | |
super(CharbonnierLoss, self).__init__() | |
self.eps = eps | |
self.out_norm = out_norm | |
def forward(self, x, y): | |
norm = get_outnorm(x, self.out_norm) | |
loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2)) | |
return loss * norm | |
def feature_vae_loss(feature): | |
# kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0) | |
# feature dim: B C H W | |
mu = feature.mean(1) | |
var = feature.var(1) | |
log_var = torch.log(var) | |
kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu**2 - var, dim=1), dim=0) | |
return kld | |
def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff): | |
# return max(min(max_kl_coeff * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff) | |
kl_lambda = max( | |
min( | |
min_kl_coeff + (max_kl_coeff - min_kl_coeff) * | |
(step - constant_step) / total_step, max_kl_coeff), min_kl_coeff) | |
return torch.tensor(kl_lambda, device=dist_util.dev()) | |
def depth_smoothness_loss(alpha_pred, depth_pred): | |
# from PesonNeRF paper. | |
# all Tensor shape B 1 H W | |
geom_loss = ( | |
alpha_pred[..., :-1] * alpha_pred[..., 1:] * ( | |
depth_pred[..., :-1] - depth_pred[..., 1:] # W dim | |
).square()).mean() # mean of ([8, 1, 64, 63]) | |
geom_loss += (alpha_pred[..., :-1, :] * alpha_pred[..., 1:, :] * | |
(depth_pred[..., :-1, :] - depth_pred[..., 1:, :]).square() | |
).mean() # H dim, ([8, 1, 63, 64]) | |
return geom_loss | |
# https://github.com/elliottwu/unsup3d/blob/master/unsup3d/networks.py#L140 | |
class LPIPSLoss(torch.nn.Module): | |
def __init__( | |
self, | |
loss_weight=1.0, | |
use_input_norm=True, | |
range_norm=True, | |
# n1p1_input=True, | |
): | |
super(LPIPSLoss, self).__init__() | |
# self.perceptual = lpips.LPIPS(net="alex", spatial=False).eval() | |
self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() | |
self.loss_weight = loss_weight | |
self.use_input_norm = use_input_norm | |
self.range_norm = range_norm | |
# if self.use_input_norm: | |
# # the mean is for image with range [0, 1] | |
# self.register_buffer( | |
# 'mean', | |
# torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) | |
# # the std is for image with range [0, 1] | |
# self.register_buffer( | |
# 'std', | |
# torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) | |
def forward(self, pred, target, conf_sigma_percl=None): | |
# st() | |
# ! add large image support, only sup 128x128 patch | |
lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) | |
return self.loss_weight * lpips_loss.mean() | |
# mask-aware perceptual loss | |
class PerceptualLoss(nn.Module): | |
def __init__(self, requires_grad=False): | |
super(PerceptualLoss, self).__init__() | |
mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406]) | |
std_rgb = torch.FloatTensor([0.229, 0.224, 0.225]) | |
self.register_buffer('mean_rgb', mean_rgb) | |
self.register_buffer('std_rgb', std_rgb) | |
vgg_pretrained_features = torchvision.models.vgg16( | |
pretrained=True).features | |
self.slice1 = nn.Sequential() | |
self.slice2 = nn.Sequential() | |
self.slice3 = nn.Sequential() | |
self.slice4 = nn.Sequential() | |
for x in range(4): | |
self.slice1.add_module(str(x), vgg_pretrained_features[x]) | |
for x in range(4, 9): | |
self.slice2.add_module(str(x), vgg_pretrained_features[x]) | |
for x in range(9, 16): | |
self.slice3.add_module(str(x), vgg_pretrained_features[x]) | |
for x in range(16, 23): | |
self.slice4.add_module(str(x), vgg_pretrained_features[x]) | |
if not requires_grad: | |
for param in self.parameters(): | |
param.requires_grad = False | |
def normalize(self, x): | |
out = x / 2 + 0.5 | |
out = (out - self.mean_rgb.view(1, 3, 1, 1)) / self.std_rgb.view( | |
1, 3, 1, 1) | |
return out | |
def __call__(self, im1, im2, mask=None, conf_sigma=None): | |
im = torch.cat([im1, im2], 0) | |
im = self.normalize(im) # normalize input | |
## compute features | |
feats = [] | |
f = self.slice1(im) | |
feats += [torch.chunk(f, 2, dim=0)] | |
f = self.slice2(f) | |
feats += [torch.chunk(f, 2, dim=0)] | |
f = self.slice3(f) | |
feats += [torch.chunk(f, 2, dim=0)] | |
f = self.slice4(f) | |
feats += [torch.chunk(f, 2, dim=0)] | |
losses = [] | |
for f1, f2 in feats[2:3]: # use relu3_3 features only | |
loss = (f1 - f2)**2 | |
if conf_sigma is not None: | |
loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma + | |
EPS).log() | |
if mask is not None: | |
b, c, h, w = loss.shape | |
_, _, hm, wm = mask.shape | |
sh, sw = hm // h, wm // w | |
mask0 = nn.functional.avg_pool2d(mask, | |
kernel_size=(sh, sw), | |
stride=(sh, | |
sw)).expand_as(loss) | |
loss = (loss * mask0).sum() / mask0.sum() | |
else: | |
loss = loss.mean() | |
losses += [loss] | |
return sum(losses) | |
# add confidence support, unsup3d version | |
def photometric_loss_laplace(im1, im2, mask=None, conf_sigma=None): | |
loss = (im1 - im2).abs() | |
# loss = (im1 - im2).square() | |
if conf_sigma is not None: | |
loss = loss * 2**0.5 / (conf_sigma + EPS) + (conf_sigma + EPS).log() | |
if mask is not None: | |
mask = mask.expand_as(loss) | |
loss = (loss * mask).sum() / mask.sum() | |
else: | |
loss = loss.mean() | |
return loss | |
# gaussian likelihood version, What Uncertainties Do We Need in Bayesian Deep Learning for Computer Vision? | |
# also used in the mask-aware vgg loss | |
def photometric_loss(im1, im2, mask=None, conf_sigma=None): | |
# loss = torch.nn.functional.mse_loss(im1, im2, reduce='none') | |
loss = (im1 - im2).square() | |
if conf_sigma is not None: | |
loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma + EPS).log() | |
if mask is not None: | |
mask = mask.expand_as(loss) | |
loss = (loss * mask).sum() / mask.sum() | |
else: | |
loss = loss.mean() | |
return loss | |
class E3DGELossClass(torch.nn.Module): | |
def __init__(self, device, opt) -> None: | |
super().__init__() | |
self.opt = opt | |
self.device = device | |
self.criterionImg = { | |
'mse': torch.nn.MSELoss(), | |
'l1': torch.nn.L1Loss(), | |
'charbonnier': CharbonnierLoss(), | |
}[opt.color_criterion] | |
self.criterion_latent = { | |
'mse': torch.nn.MSELoss(), | |
'l1': torch.nn.L1Loss(), | |
'vae': feature_vae_loss | |
}[opt.latent_criterion] | |
# self.criterionLPIPS = LPIPS(net_type='alex', device=device).eval() | |
if opt.lpips_lambda > 0: | |
self.criterionLPIPS = LPIPSLoss(loss_weight=opt.lpips_lambda) | |
# self.criterionLPIPS = torch.nn.MSELoss() | |
if opt.id_lambda > 0: | |
self.criterionID = IDLoss(device=device).eval() | |
self.id_loss_pool = torch.nn.AdaptiveAvgPool2d((256, 256)) | |
# define 3d rec loss, for occupancy | |
# self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none') | |
# self.criterion_alpha = torch.nn.SmoothL1Loss() | |
# self.criterion3d_rec = torch.nn.MSELoss(reduction='none') | |
self.criterion_alpha = torch.nn.L1Loss() | |
if self.opt.xyz_lambda > 0: | |
# self.criterion_xyz = torch.nn.SmoothL1Loss() | |
self.criterion_xyz = torch.nn.L1Loss() # follow LION, but noisy xyz here... | |
if self.opt.depth_lambda > 0: | |
# ! this depth loss not converging, no idea why | |
self.criterion3d_rec = ScaleAndShiftInvariantLoss(alpha=0.5, | |
scales=1) | |
else: | |
self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none') | |
# self.silog_loss = SILogLoss() | |
if self.opt.lambda_opa_reg > 0: | |
# self.beta_mvp_dist = torch.distributions.beta.Beta(torch.tensor(0.5, device=device), torch.tensor(0.5, device=device)) | |
# self.beta_mvp_base_dist = torch.distributions.beta.Beta(torch.tensor(10, device=device), torch.tensor(0.5, device=device)) # force close to 1 for base | |
# self.beta_mvp_base_dist = torch.distributions.beta.Beta(torch.tensor(0.6, device=device), torch.tensor(0.2, device=device)) # force close to 1 for base | |
self.beta_mvp_base_dist = torch.distributions.beta.Beta(torch.tensor(0.5, device=device), torch.tensor(0.25, device=device)) # force close to 1 for base | |
logger.log('init loss class finished', ) | |
def calc_scale_invariant_depth_loss(self, pred_depth: torch.Tensor, | |
gt_depth: torch.Tensor, | |
gt_depth_mask: torch.Tensor): | |
"""apply 3d shape reconstruction supervision. Basically supervise the depth with L1 loss | |
""" | |
shape_loss_dict = {} | |
assert gt_depth_mask is not None | |
shape_loss = self.criterion3d_rec(pred_depth, gt_depth, gt_depth_mask) | |
# if shape_loss > 0.2: # hinge loss, avoid ood gradient | |
# shape_loss = torch.zeros_like(shape_loss) | |
# else: | |
shape_loss = shape_loss.clamp(0.04) # g-buffer depth is very noisy | |
shape_loss *= self.opt.depth_lambda | |
shape_loss_dict['loss_depth'] = shape_loss | |
# shape_loss_dict['depth_fgratio'] = gt_depth_mask.mean() | |
# return l_si, shape_loss_dict | |
return shape_loss, shape_loss_dict | |
def calc_depth_loss(self, pred_depth: torch.Tensor, gt_depth: torch.Tensor, | |
gt_depth_mask: torch.Tensor): | |
"""apply 3d shape reconstruction supervision. Basically supervise the depth with L1 loss | |
""" | |
shape_loss_dict = {} | |
shape_loss = self.criterion3d_rec(pred_depth, gt_depth) | |
assert gt_depth_mask is not None | |
shape_loss *= gt_depth_mask | |
shape_loss = shape_loss.sum() / gt_depth_mask.sum() | |
# else: | |
# shape_loss /= pred_depth.numel() | |
# l_si = self.silog_loss(pred_depth, gt_depth, mask=None, interpolate=True, return_interpolated=False) | |
# l_si *= self.opt.depth_lambda | |
# shape_loss_dict['loss_depth'] = l_si | |
shape_loss_dict['loss_depth'] = shape_loss.clamp( | |
min=0, max=0.1) * self.opt.depth_lambda | |
# shape_loss_dict['loss_depth'] = shape_loss.clamp( | |
# min=0, max=0.5) * self.opt.depth_lambda | |
# return l_si, shape_loss_dict | |
return shape_loss, shape_loss_dict | |
def calc_alpha_loss(self, pred_alpha, gt_depth_mask): | |
# return self.criterionImg(alpha, gt_depth_mask.float()) | |
if gt_depth_mask.ndim == 3: | |
gt_depth_mask = gt_depth_mask.unsqueeze(1) | |
if gt_depth_mask.shape[1] == 3: | |
gt_depth_mask = gt_depth_mask[:, 0:1, ...] # B 1 H W | |
assert pred_alpha.shape == gt_depth_mask.shape | |
alpha_loss = self.criterion_alpha(pred_alpha, gt_depth_mask) | |
# st() | |
return alpha_loss | |
def calc_mask_mse_loss( | |
self, | |
input, | |
gt, | |
gt_depth_mask, | |
# conf_sigma=None, | |
conf_sigma_l1=None, | |
# conf_sigma_percl=None, | |
use_fg_ratio=False): | |
if gt_depth_mask.ndim == 3: | |
gt_depth_mask = gt_depth_mask.unsqueeze(1).repeat_interleave(3, 1) | |
else: | |
assert gt_depth_mask.shape == input.shape | |
gt_depth_mask = gt_depth_mask.float() | |
if conf_sigma_l1 is None: | |
rec_loss = torch.nn.functional.mse_loss( | |
input.float(), gt.float(), | |
reduction='none') # 'sum' already divide by batch size n | |
else: | |
rec_loss = photometric_loss( | |
input, gt, gt_depth_mask, conf_sigma_l1 | |
) # ! only cauclate laplace on the foreground, or bg confidence low, large gradient. | |
return rec_loss | |
# rec_loss = torch.nn.functional.l1_loss( # for laplace loss | |
# input.float(), gt.float(), | |
# reduction='none') # 'sum' already divide by batch size n | |
# gt_depth_mask = torch.ones_like(gt_depth_mask) # ! DEBUGGING | |
# if conf_sigma is not None: # from unsup3d, but a L2 version | |
# rec_loss = rec_loss * 2**0.5 / (conf_sigma + EPS) + (conf_sigma + | |
# EPS).log() | |
# return rec_loss.mean() | |
# rec_loss = torch.exp(-(rec_loss * 2**0.5 / (conf_sigma + EPS))) * 1/(conf_sigma + | |
# EPS) / (2**0.5) | |
fg_size = gt_depth_mask.sum() | |
# fg_ratio = fg_size / torch.ones_like(gt_depth_mask).sum() if use_fg_ratio else 1 | |
fg_loss = rec_loss * gt_depth_mask | |
fg_loss = fg_loss.sum() / fg_size # * fg_ratio | |
if self.opt.bg_lamdba > 0: | |
bg_loss = rec_loss * (1 - gt_depth_mask) | |
bg_loss = bg_loss.sum() / (1 - gt_depth_mask).sum() | |
rec_loss = fg_loss + bg_loss * self.opt.bg_lamdba | |
else: | |
rec_loss = fg_loss | |
return rec_loss | |
def calc_2d_rec_loss( | |
self, | |
input, | |
gt, | |
depth_fg_mask, | |
test_mode=True, | |
step=1, | |
ignore_lpips=False, | |
# conf_sigma=None, | |
conf_sigma_l1=None, | |
conf_sigma_percl=None, | |
pred_alpha=None, | |
): | |
opt = self.opt | |
loss_dict = {} | |
# logger.log(test_mode) | |
# logger.log(input.min(), input.max(), gt.min(), gt.max()) | |
if test_mode or not opt.fg_mse: | |
rec_loss = self.criterionImg(input, gt) | |
else: | |
rec_loss = self.calc_mask_mse_loss( | |
input, | |
gt, | |
depth_fg_mask, | |
conf_sigma_l1=conf_sigma_l1, | |
) | |
# conf_sigma_percl=conf_sigma_percl) | |
# conf_sigma) | |
# if step == 300: | |
# st() | |
if opt.lpips_lambda > 0 and step >= opt.lpips_delay_iter and not ignore_lpips: # tricky solution to avoid NAN in LPIPS loss | |
# with torch.autocast(device_type='cuda', | |
# dtype=torch.float16, | |
# enabled=False): | |
# if test_mode or not opt.fg_mse: # no need to calculate background lpips for ease of computation | |
# inp_for_lpips = input * pred_alpha + torch.ones_like(input) * (1-pred_alpha) | |
# gt_for_lpips = gt * depth_fg_mask + torch.ones_like(gt) * (1-depth_fg_mask) | |
inp_for_lpips = input * pred_alpha | |
gt_for_lpips = gt * depth_fg_mask | |
width = input.shape[-1] | |
if width == 192: # triplane here | |
lpips_loss = self.criterionLPIPS( # loss on 128x128 center crop | |
inp_for_lpips[:, :, width//2-64:width//2+64, width//2-64:width//2+64], | |
gt_for_lpips[:, :, width//2-64:width//2+64, width//2-64:width//2+64], | |
conf_sigma_percl=conf_sigma_percl, | |
) | |
elif width >256: | |
# elif width >192: | |
# lpips_loss = self.criterionLPIPS( | |
# F.interpolate(inp_for_lpips, (256,256), mode='bilinear'), | |
# F.interpolate(gt_for_lpips, (256,256), mode='bilinear'), | |
# conf_sigma_percl=conf_sigma_percl, | |
# ) | |
# patch = 80 | |
# patch = 128 | |
patch = 144 | |
middle_point = width // 2 | |
lpips_loss = self.criterionLPIPS( # loss on 128x128 center crop | |
inp_for_lpips[:, :, middle_point-patch:middle_point+patch, middle_point-patch:middle_point+patch], | |
gt_for_lpips[:, :, middle_point-patch:middle_point+patch, middle_point-patch:middle_point+patch], | |
conf_sigma_percl=conf_sigma_percl, | |
) | |
else: # directly supervise when <= 256 | |
# ! add foreground mask | |
assert pred_alpha is not None | |
lpips_loss = self.criterionLPIPS( | |
inp_for_lpips, | |
gt_for_lpips, | |
# conf_sigma_percl=conf_sigma_percl, | |
) | |
# else: # fg lpips | |
# assert depth_fg_mask.shape == input.shape | |
# lpips_loss = self.criterionLPIPS( | |
# input.contiguous() * depth_fg_mask, | |
# gt.contiguous() * depth_fg_mask).mean() | |
else: | |
lpips_loss = torch.tensor(0., device=input.device) | |
if opt.ssim_lambda > 0: | |
loss_ssim = self.ssim_loss(input, gt) #? | |
else: | |
loss_ssim = torch.tensor(0., device=input.device) | |
loss_psnr = self.psnr((input / 2 + 0.5), (gt / 2 + 0.5), 1.0) | |
if opt.id_lambda > 0: | |
loss_id = self._calc_loss_id(input, gt) | |
else: | |
loss_id = torch.tensor(0., device=input.device) | |
if opt.l1_lambda > 0: | |
loss_l1 = F.l1_loss(input, gt) | |
else: | |
loss_l1 = torch.tensor(0., device=input.device) | |
# loss = rec_loss * opt.l2_lambda + lpips_loss * opt.lpips_lambda + loss_id * opt.id_lambda + loss_ssim * opt.ssim_lambda | |
rec_loss = rec_loss * opt.l2_lambda | |
loss = rec_loss + lpips_loss + loss_id * opt.id_lambda + loss_ssim * opt.ssim_lambda + opt.l1_lambda * loss_l1 | |
# if return_dict: | |
loss_dict['loss_l2'] = rec_loss | |
loss_dict['loss_id'] = loss_id | |
loss_dict['loss_lpips'] = lpips_loss | |
loss_dict['loss'] = loss | |
loss_dict['loss_ssim'] = loss_ssim | |
# metrics to report, not involved in training | |
loss_dict['mae'] = loss_l1 | |
loss_dict['PSNR'] = loss_psnr | |
loss_dict['SSIM'] = 1 - loss_ssim # Todo | |
loss_dict['ID_SIM'] = 1 - loss_id | |
return loss, loss_dict | |
def calc_shape_rec_loss( | |
self, | |
pred_shape: dict, | |
gt_shape: dict, | |
device, | |
): | |
"""apply 3d shape reconstruction supervision. Basically supervise the densities with L1 loss | |
Args: | |
pred_shape (dict): dict contains reconstructed shape information | |
gt_shape (dict): dict contains gt shape information | |
supervise_sdf (bool, optional): whether supervise sdf rec. Defaults to True. | |
supervise_surface_normal (bool, optional): whether supervise surface rec. Defaults to False. | |
Returns: | |
dict: shape reconstruction loss | |
""" | |
shape_loss_dict = {} | |
shape_loss = 0 | |
# assert supervise_sdf or supervise_surface_normal, 'should at least supervise one types of shape reconstruction' | |
# todo, add weights | |
if self.opt.shape_uniform_lambda > 0: | |
shape_loss_dict['coarse'] = self.criterion3d_rec( | |
pred_shape['coarse_densities'].squeeze(), | |
gt_shape['coarse_densities'].squeeze()) | |
shape_loss += shape_loss_dict[ | |
'coarse'] * self.opt.shape_uniform_lambda | |
if self.opt.shape_importance_lambda > 0: | |
shape_loss_dict['fine'] = self.criterion3d_rec( | |
pred_shape['fine_densities'].squeeze(), # ? how to supervise | |
gt_shape['fine_densities'].squeeze()) | |
shape_loss += shape_loss_dict[ | |
'fine'] * self.opt.shape_importance_lambda | |
loss_depth = self.criterion_alpha(pred_shape['image_depth'], | |
gt_shape['image_depth']) | |
shape_loss += loss_depth * self.opt.shape_depth_lambda | |
shape_loss_dict.update(dict(loss_depth=loss_depth)) | |
# TODO, add on surface pts supervision ? | |
return shape_loss, shape_loss_dict | |
def psnr(self, input, target, max_val): | |
return kornia.metrics.psnr(input, target, max_val) | |
# @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False) | |
def ssim_loss(self, img1, img2, window_size=11, size_average=True): | |
channel = img1.size(-3) | |
window = create_window(window_size, channel) | |
if img1.is_cuda: | |
window = window.cuda(img1.get_device()) | |
window = window.type_as(img1) | |
return 1 - _ssim(img1, img2, window, window_size, channel, size_average) | |
def forward(self, | |
pred, | |
gt, | |
test_mode=True, | |
step=1, | |
return_fg_mask=False, | |
conf_sigma_l1=None, | |
conf_sigma_percl=None, | |
ignore_kl=False, | |
ignore_lpips=False, | |
*args, | |
**kwargs): | |
with torch.autocast(device_type='cuda', | |
dtype=torch.float16, | |
enabled=False): | |
loss = torch.tensor(0., device=self.device) | |
loss_dict = {} | |
if 'image_mask' in pred: | |
pred_alpha = pred['image_mask'] # B 1 H W | |
else: | |
N, _, H, W = pred['image_depth'].shape | |
pred_alpha = pred['weights_samples'].permute(0, 2, 1).reshape( | |
N, 1, H, W) | |
# balance rec_loss with logvar | |
# if 'depth_mask' in gt: | |
if self.opt.online_mask: | |
# https://github.com/elliottwu/unsup3d/blob/dc961410d61684561f19525c2f7e9ee6f4dacb91/unsup3d/model.py#L193 | |
margin = (self.opt.max_depth - self.opt.min_depth) / 2 | |
fg_mask = (pred['image_depth'] | |
< self.opt.max_depth + margin).float() # B 1 H W | |
fg_mask = fg_mask.repeat_interleave(3, 1).float() | |
else: | |
if 'depth_mask' in gt: | |
if gt['depth_mask'].shape[1] != 1: | |
fg_mask = gt['depth_mask'].unsqueeze(1) | |
else: | |
fg_mask = gt['depth_mask'] | |
fg_mask = fg_mask.repeat_interleave( | |
3, 1).float() | |
else: | |
fg_mask = None | |
loss_2d, loss_2d_dict = self.calc_2d_rec_loss( | |
pred['image_raw'], | |
gt['img'], | |
fg_mask, | |
test_mode=test_mode, | |
step=step, | |
ignore_lpips=ignore_lpips, | |
conf_sigma_l1=conf_sigma_l1, | |
conf_sigma_percl=conf_sigma_percl, | |
pred_alpha=pred_alpha, | |
) | |
# ignore_lpips=self.opt.fg_mse) | |
if self.opt.kl_lambda > 0 and not ignore_kl: | |
# assert 'posterior' in pred, 'logvar' in pred | |
assert 'posterior' in pred | |
if self.opt.kl_anneal: | |
kl_lambda = kl_coeff( | |
step=step, | |
constant_step=5e3, # 1w steps | |
total_step=25e3, # 5w steps in total | |
min_kl_coeff=max(1e-9, self.opt.kl_lambda / 1e4), | |
max_kl_coeff=self.opt.kl_lambda) | |
loss_dict['kl_lambda'] = kl_lambda | |
else: | |
loss_dict['kl_lambda'] = torch.tensor( | |
self.opt.kl_lambda, device=dist_util.dev()) | |
if self.opt.pt_ft_kl: | |
pt_kl, ft_kl = pred['posterior'].kl(pt_ft_separate=True) | |
kl_batch = pt_kl.shape[0] | |
# loss_dict['kl_loss_pt'] = pt_kl.sum() * loss_dict['kl_lambda'] * 0.01 / kl_batch | |
loss_dict['kl_loss_pt'] = pt_kl.sum() * loss_dict['kl_lambda'] * 0 # no compression at all. | |
loss_dict['kl_loss_ft'] = ft_kl.sum() * loss_dict['kl_lambda'] / kl_batch | |
loss = loss + loss_dict['kl_loss_pt'] + loss_dict['kl_loss_ft'] | |
loss_dict['latent_mu_pt'] = pred['posterior'].mean[:, :3].mean() | |
loss_dict['latent_std_pt'] = pred['posterior'].std[:, :3].mean() | |
loss_dict['latent_mu_ft'] = pred['posterior'].mean[:, 3:].mean() | |
loss_dict['latent_std_ft'] = pred['posterior'].std[:, 3:].mean() | |
elif self.opt.ft_kl: | |
ft_kl = pred['posterior'].kl(ft_separate=True) | |
kl_batch = ft_kl.shape[0] | |
loss_dict['kl_loss_ft'] = ft_kl.sum() * loss_dict['kl_lambda'] / kl_batch | |
loss = loss + loss_dict['kl_loss_ft'] | |
loss_dict['latent_mu_ft'] = pred['posterior'].mean[:, :].square().mean().float().detach() | |
loss_dict['latent_std_ft'] = pred['posterior'].std[:, :].mean().float().detach() | |
else: | |
kl_loss = pred['posterior'].kl() | |
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] | |
loss_dict['kl_loss'] = kl_loss * loss_dict['kl_lambda'] | |
loss += loss_dict['kl_loss'] | |
loss_dict['latent_mu'] = pred['posterior'].mean.mean() | |
loss_dict['latent_std'] = pred['posterior'].std.mean() | |
# nll_loss = loss_2d / torch.exp(pred['logvar']) + pred['logvar'] # nll_loss | |
nll_loss = loss_2d | |
loss += nll_loss | |
loss_dict.update(dict(nll_loss=nll_loss)) | |
# loss_dict['latent_mu'] = pred['latent_normalized'].mean() | |
# loss_dict['latent_max'] = pred['latent_normalized'].max() | |
# loss_dict['latent_min'] = pred['latent_normalized'].min() | |
# loss_dict['latent_std'] = pred['latent_normalized'].std() | |
# pred[ | |
# 'latent_normalized_2Ddiffusion'].mean() | |
# loss_dict['latent_std'] = pred[ | |
# 'latent_normalized_2Ddiffusion'].std() | |
# loss_dict['latent_max'] = pred[ | |
# 'latent_normalized_2Ddiffusion'].max() | |
# loss_dict['latent_min'] = pred[ | |
# 'latent_normalized_2Ddiffusion'].min() | |
else: | |
loss += loss_2d | |
# if 'image_sr' in pred and pred['image_sr'].shape==gt['img_sr']: | |
if 'image_sr' in pred: | |
if 'depth_mask_sr' in gt: | |
depth_mask_sr = gt['depth_mask_sr'].unsqueeze( | |
1).repeat_interleave(3, 1).float() | |
else: | |
depth_mask_sr = None | |
loss_sr, loss_sr_dict = self.calc_2d_rec_loss( | |
pred['image_sr'], | |
gt['img_sr'], | |
depth_fg_mask=depth_mask_sr, | |
# test_mode=test_mode, | |
test_mode=True, | |
step=step) | |
loss_sr_lambda = 1 | |
if step < self.opt.sr_delay_iter: | |
loss_sr_lambda = 0 | |
loss += loss_sr * loss_sr_lambda | |
for k, v in loss_sr_dict.items(): | |
loss_dict['sr_' + k] = v * loss_sr_lambda | |
if self.opt.depth_lambda > 0: # TODO, switch to scale-agnostic depth loss | |
assert 'depth' in gt | |
pred_depth = pred['image_depth'] | |
if pred_depth.ndim == 4: | |
pred_depth = pred_depth.squeeze(1) # B H W | |
# _, shape_loss_dict = self.calc_depth_loss( | |
# pred_depth, gt['depth'], fg_mask[:, 0, ...]) | |
_, shape_loss_dict = self.calc_scale_invariant_depth_loss( | |
pred_depth, gt['depth'], fg_mask[:, 0, ...]) | |
loss += shape_loss_dict['loss_depth'] | |
loss_dict.update(shape_loss_dict) | |
# if self.opt.latent_lambda > 0: # make sure the latent suits diffusion learning | |
# latent_mu = pred['latent'].mean() | |
# loss_latent = self.criterion_latent( | |
# latent_mu, torch.zeros_like( | |
# latent_mu)) # only regularize the mean value here | |
# loss_dict['loss_latent'] = loss_latent | |
# loss += loss_latent * self.opt.latent_lambda | |
if self.opt.alpha_lambda > 0 and 'image_depth' in pred: | |
loss_alpha = self.calc_alpha_loss(pred_alpha, fg_mask) | |
loss_dict['loss_alpha'] = loss_alpha * self.opt.alpha_lambda | |
loss += loss_alpha * self.opt.alpha_lambda | |
if self.opt.depth_smoothness_lambda > 0: | |
loss_depth_smoothness = depth_smoothness_loss( | |
pred_alpha, | |
pred['image_depth']) * self.opt.depth_smoothness_lambda | |
loss_dict['loss_depth_smoothness'] = loss_depth_smoothness | |
loss += loss_depth_smoothness | |
loss_2d_dict['all_loss'] = loss | |
loss_dict.update(loss_2d_dict) | |
# if return_fg_mask: | |
return loss, loss_dict, fg_mask | |
# else: | |
# return loss, loss_dict | |
def _calc_loss_id(self, input, gt): | |
if input.shape[-1] != 256: | |
arcface_input = self.id_loss_pool(input) | |
id_loss_gt = self.id_loss_pool(gt) | |
else: | |
arcface_input = input | |
id_loss_gt = gt | |
loss_id, _, _ = self.criterionID(arcface_input, id_loss_gt, id_loss_gt) | |
return loss_id | |
def calc_2d_rec_loss_misaligned(self, input, gt): | |
"""id loss + vgg loss | |
Args: | |
input (_type_): _description_ | |
gt (_type_): _description_ | |
depth_mask (_type_): _description_ | |
test_mode (bool, optional): _description_. Defaults to True. | |
""" | |
opt = self.opt | |
loss_dict = {} | |
if opt.lpips_lambda > 0: | |
with torch.autocast( | |
device_type='cuda', dtype=torch.float16, | |
enabled=False): # close AMP for lpips to avoid nan | |
lpips_loss = self.criterionLPIPS(input, gt) | |
else: | |
lpips_loss = torch.tensor(0., device=input.device) | |
if opt.id_lambda > 0: | |
loss_id = self._calc_loss_id(input, gt) | |
else: | |
loss_id = torch.tensor(0., device=input.device) | |
loss_dict['loss_id_real'] = loss_id | |
loss_dict['loss_lpips_real'] = lpips_loss | |
loss = lpips_loss * opt.lpips_lambda + loss_id * opt.id_lambda | |
return loss, loss_dict | |
class E3DGE_with_AdvLoss(E3DGELossClass): | |
# adapted from sgm/modules/autoencoding/losses/discriminator_loss.py | |
def __init__( | |
self, | |
device, | |
opt, | |
discriminator_config: Optional[Dict] = None, | |
disc_num_layers: int = 3, | |
disc_in_channels: int = 3, | |
disc_start: int = 0, | |
disc_loss: str = "hinge", | |
disc_factor: float = 1.0, | |
disc_weight: float = 1.0, | |
regularization_weights: Union[None, Dict[str, float]] = None, | |
dtype=torch.float32, | |
# additional_log_keys: Optional[List[str]] = None, | |
) -> None: | |
super().__init__( | |
device, | |
opt, | |
) | |
# ! initialize GAN loss | |
discriminator_config = default( | |
discriminator_config, | |
{ | |
"target": | |
"nsr.losses.disc.NLayerDiscriminator", | |
"params": { | |
"input_nc": disc_in_channels, | |
"n_layers": disc_num_layers, | |
"use_actnorm": False, | |
}, | |
}, | |
) | |
self.discriminator = instantiate_from_config( | |
discriminator_config).apply(weights_init) | |
self.discriminator_iter_start = disc_start | |
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss | |
self.disc_factor = disc_factor | |
self.discriminator_weight = disc_weight # self.regularization_weights = default(regularization_weights, {}) | |
# self.forward_keys = [ | |
# "optimizer_idx", | |
# "global_step", | |
# "last_layer", | |
# "split", | |
# "regularization_log", | |
# ] | |
# self.additional_log_keys = set(default(additional_log_keys, [])) | |
# self.additional_log_keys.update(set( | |
# self.regularization_weights.keys())) | |
def get_trainable_parameters(self) -> Iterator[nn.Parameter]: | |
return self.discriminator.parameters() | |
def forward(self, | |
pred, | |
gt, | |
behaviour: str, | |
test_mode=True, | |
step=1, | |
return_fg_mask=False, | |
conf_sigma_l1=None, | |
conf_sigma_percl=None, | |
ignore_d_loss=False, | |
*args, | |
**kwargs): | |
# now the GAN part | |
reconstructions = pred['image_raw'] | |
inputs = gt['img'] | |
if behaviour == 'g_step': | |
nll_loss, loss_dict, fg_mask = super().forward( | |
pred, | |
gt, | |
test_mode, | |
step, | |
return_fg_mask, | |
conf_sigma_l1, | |
conf_sigma_percl, | |
*args, | |
**kwargs) | |
# generator update | |
if not ignore_d_loss and (step >= self.discriminator_iter_start or not self.training): | |
logits_fake = self.discriminator(reconstructions.contiguous()) | |
g_loss = -torch.mean(logits_fake) | |
if self.training: | |
d_weight = torch.tensor(self.discriminator_weight) | |
# d_weight = self.calculate_adaptive_weight( | |
# nll_loss, g_loss, last_layer=last_layer) | |
else: | |
d_weight = torch.tensor(1.0) | |
else: | |
d_weight = torch.tensor(0.0) | |
g_loss = torch.tensor(0.0, requires_grad=True) | |
g_loss = g_loss * d_weight * self.disc_factor | |
loss = nll_loss + g_loss | |
# TODO | |
loss_dict.update({ | |
f"loss/g": g_loss.detach().mean(), | |
}) | |
# return loss, log | |
return loss, loss_dict, fg_mask | |
elif behaviour == 'd_step' and not ignore_d_loss: | |
# second pass for discriminator update | |
logits_real = self.discriminator(inputs.contiguous().detach()) | |
logits_fake = self.discriminator( | |
reconstructions.contiguous().detach()) | |
if step >= self.discriminator_iter_start or not self.training: | |
d_loss = self.disc_factor * self.disc_loss( | |
logits_real, logits_fake) | |
else: | |
d_loss = torch.tensor(0.0, requires_grad=True) | |
loss_dict = {} | |
loss_dict.update({ | |
"loss/disc": d_loss.clone().detach().mean(), | |
"logits/real": logits_real.detach().mean(), | |
"logits/fake": logits_fake.detach().mean(), | |
}) | |
return d_loss, loss_dict, None | |
else: | |
raise NotImplementedError(f"Unknown optimizer behaviour {behaviour}") | |