EPS = 1e-7

import kornia
from typing import Dict, Iterator, List, Optional, Tuple, Union
import torchvision
from guided_diffusion import dist_util, logger
from pdb import set_trace as st
from torch.nn import functional as F
import numpy as np
import torch
import torch.nn as nn
import lpips

from . import *
from .sdfstudio_losses import ScaleAndShiftInvariantLoss
from ldm.util import default, instantiate_from_config
from .vqperceptual import hinge_d_loss, vanilla_d_loss

from torch.autograd import Variable
from math import exp


def gaussian(window_size, sigma):
    gauss = torch.Tensor([
        exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
        for x in range(window_size)
    ])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(
        _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(
        _2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(
        img1 * img1, window, padding=window_size // 2,
        groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(
        img2 * img2, window, padding=window_size // 2,
        groups=channel) - mu2_sq
    sigma12 = F.conv2d(
        img1 * img2, window, padding=window_size // 2,
        groups=channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
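

# --- Illustrative sketch (added for clarity, not part of the original training code) ---
# Minimal example of how the SSIM helpers above fit together: build a Gaussian
# window with `create_window` and score two images with `_ssim`. The tensors
# are random placeholders; any B x C x H x W pair works the same way.
def _example_ssim_sketch():
    img1 = torch.rand(2, 3, 64, 64)
    img2 = img1.clone()  # identical images -> SSIM of 1
    window = create_window(window_size=11, channel=3).type_as(img1)
    score = _ssim(img1, img2, window, window_size=11, channel=3)
    return score  # tensor(1.) up to numerical precision
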
# Main loss function used for ZoeDepth.
# Copy/paste from AdaBins repo
# (https://github.com/shariqfarooq123/AdaBins/blob/0952d91e9e762be310bb4cd055cbfe2448c0ce20/loss.py#L7)
def extract_key(prediction, key):
    if isinstance(prediction, dict):
        return prediction[key]
    return prediction


class SILogLoss(nn.Module):
    """SILog loss (pixel-wise)"""

    def __init__(self, beta=0.15):
        super(SILogLoss, self).__init__()
        self.name = 'SILog'
        self.beta = beta

    def forward(self,
                input,
                target,
                mask=None,
                interpolate=True,
                return_interpolated=False):
        # input = extract_key(input, KEY_OUTPUT)
        if input.shape[-1] != target.shape[-1] and interpolate:
            input = nn.functional.interpolate(input,
                                              target.shape[-2:],
                                              mode='bilinear',
                                              align_corners=True)
            intr_input = input
        else:
            intr_input = input

        if target.ndim == 3:
            target = target.unsqueeze(1)

        if mask is not None:
            if mask.ndim == 3:
                mask = mask.unsqueeze(1)

            input = input[mask]
            target = target[mask]

        # with torch.amp.autocast(enabled=False):  # amp causes NaNs in this loss function
        alpha = 1e-7
        g = torch.log(input + alpha) - torch.log(target + alpha)

        # n, c, h, w = g.shape
        # norm = 1/(h*w)
        # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2

        Dg = torch.var(g) + self.beta * torch.pow(torch.mean(g), 2)

        loss = 10 * torch.sqrt(Dg)

        if torch.isnan(loss):
            print("Nan SILog loss")
            print("input:", input.shape)
            print("target:", target.shape)
            print("G", torch.sum(torch.isnan(g)))
            print("Input min max", torch.min(input), torch.max(input))
            print("Target min max", torch.min(target), torch.max(target))
            print("Dg", torch.isnan(Dg))
            print("loss", torch.isnan(loss))

        if not return_interpolated:
            return loss

        return loss, intr_input


def get_outnorm(x: torch.Tensor, out_norm: str = '') -> torch.Tensor:
    """Common function to get a loss normalization value. Can normalize by
    either the batch size ('b'), the number of channels ('c'), the image
    size ('i') or combinations ('bi', 'bci', etc).
    """
    # b, c, h, w = x.size()
    img_shape = x.shape

    if not out_norm:
        return 1

    norm = 1
    if 'b' in out_norm:
        # normalize by batch size
        # norm /= b
        norm /= img_shape[0]
    if 'c' in out_norm:
        # normalize by the number of channels
        # norm /= c
        norm /= img_shape[-3]
    if 'i' in out_norm:
        # normalize by image/map size
        # norm /= h*w
        norm /= img_shape[-1] * img_shape[-2]

    return norm


class CharbonnierLoss(torch.nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-6, out_norm: str = 'bci'):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps
        self.out_norm = out_norm

    def forward(self, x, y):
        norm = get_outnorm(x, self.out_norm)
        loss = torch.sum(torch.sqrt((x - y).pow(2) + self.eps**2))
        return loss * norm


def feature_vae_loss(feature):
    # kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim = 1), dim = 0)

    # feature dim: B C H W
    mu = feature.mean(1)
    var = feature.var(1)
    log_var = torch.log(var)
    kld = torch.mean(-0.5 * torch.sum(1 + log_var - mu**2 - var, dim=1),
                     dim=0)
    return kld


def kl_coeff(step, total_step, constant_step, min_kl_coeff, max_kl_coeff):
    # return max(min(max_kl_coeff * (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
    kl_lambda = max(
        min(
            min_kl_coeff + (max_kl_coeff - min_kl_coeff) *
            (step - constant_step) / total_step, max_kl_coeff), min_kl_coeff)
    return torch.tensor(kl_lambda, device=dist_util.dev())
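

# --- Illustrative sketch (added for clarity, not part of the original training code) ---
# `kl_coeff` above linearly anneals the KL weight from `min_kl_coeff` to
# `max_kl_coeff`, starting after `constant_step` and ramping over `total_step`
# steps. The helper below mirrors that ramp in plain Python (it skips the
# `dist_util.dev()` device placement so it runs without a distributed setup);
# the numbers are placeholder hyper-parameters, not values from the project.
def _example_kl_anneal_schedule(total_step=25e3,
                                constant_step=5e3,
                                min_kl_coeff=1e-9,
                                max_kl_coeff=1e-5):

    def ramp(step):
        return max(
            min(
                min_kl_coeff + (max_kl_coeff - min_kl_coeff) *
                (step - constant_step) / total_step, max_kl_coeff),
            min_kl_coeff)

    # stays at min_kl_coeff during the constant phase, grows linearly
    # afterwards, and saturates at max_kl_coeff
    return [ramp(s) for s in (0, 5e3, 15e3, 30e3)]
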
def depth_smoothness_loss(alpha_pred, depth_pred):
    # from the PersonNeRF paper.
    # all tensors have shape B 1 H W
    geom_loss = (
        alpha_pred[..., :-1] * alpha_pred[..., 1:] *
        (
            depth_pred[..., :-1] - depth_pred[..., 1:]  # W dim
        ).square()).mean()  # mean of ([8, 1, 64, 63])

    geom_loss += (alpha_pred[..., :-1, :] * alpha_pred[..., 1:, :] *
                  (depth_pred[..., :-1, :] -
                   depth_pred[..., 1:, :]).square()
                  ).mean()  # H dim, ([8, 1, 63, 64])

    return geom_loss


# https://github.com/elliottwu/unsup3d/blob/master/unsup3d/networks.py#L140
class LPIPSLoss(torch.nn.Module):

    def __init__(
            self,
            loss_weight=1.0,
            use_input_norm=True,
            range_norm=True,
            # n1p1_input=True,
    ):
        super(LPIPSLoss, self).__init__()
        # self.perceptual = lpips.LPIPS(net="alex", spatial=False).eval()
        self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
        self.loss_weight = loss_weight
        self.use_input_norm = use_input_norm
        self.range_norm = range_norm

        # if self.use_input_norm:
        #     # the mean is for image with range [0, 1]
        #     self.register_buffer(
        #         'mean',
        #         torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        #     # the std is for image with range [0, 1]
        #     self.register_buffer(
        #         'std',
        #         torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, pred, target, conf_sigma_percl=None):
        # st()
        # ! add large image support, only supervise a 128x128 patch
        lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
        return self.loss_weight * lpips_loss.mean()


# mask-aware perceptual loss
class PerceptualLoss(nn.Module):

    def __init__(self, requires_grad=False):
        super(PerceptualLoss, self).__init__()
        mean_rgb = torch.FloatTensor([0.485, 0.456, 0.406])
        std_rgb = torch.FloatTensor([0.229, 0.224, 0.225])
        self.register_buffer('mean_rgb', mean_rgb)
        self.register_buffer('std_rgb', std_rgb)

        vgg_pretrained_features = torchvision.models.vgg16(
            pretrained=True).features
        self.slice1 = nn.Sequential()
        self.slice2 = nn.Sequential()
        self.slice3 = nn.Sequential()
        self.slice4 = nn.Sequential()
        for x in range(4):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(4, 9):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(9, 16):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(16, 23):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def normalize(self, x):
        out = x / 2 + 0.5
        out = (out - self.mean_rgb.view(1, 3, 1, 1)) / self.std_rgb.view(
            1, 3, 1, 1)
        return out

    def __call__(self, im1, im2, mask=None, conf_sigma=None):
        im = torch.cat([im1, im2], 0)
        im = self.normalize(im)  # normalize input

        ## compute features
        feats = []
        f = self.slice1(im)
        feats += [torch.chunk(f, 2, dim=0)]
        f = self.slice2(f)
        feats += [torch.chunk(f, 2, dim=0)]
        f = self.slice3(f)
        feats += [torch.chunk(f, 2, dim=0)]
        f = self.slice4(f)
        feats += [torch.chunk(f, 2, dim=0)]

        losses = []
        for f1, f2 in feats[2:3]:  # use relu3_3 features only
            loss = (f1 - f2)**2
            if conf_sigma is not None:
                loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma +
                                                           EPS).log()
            if mask is not None:
                b, c, h, w = loss.shape
                _, _, hm, wm = mask.shape
                sh, sw = hm // h, wm // w
                mask0 = nn.functional.avg_pool2d(
                    mask, kernel_size=(sh, sw),
                    stride=(sh, sw)).expand_as(loss)
                loss = (loss * mask0).sum() / mask0.sum()
            else:
                loss = loss.mean()
            losses += [loss]
        return sum(losses)
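

# --- Illustrative sketch (added for clarity, not part of the original training code) ---
# `depth_smoothness_loss` above penalizes depth differences between
# neighbouring pixels along W and H, weighted by the product of the
# corresponding alpha values, so smoothness is only enforced where the
# rendering is opaque. The placeholder tensors below are B 1 H W, matching the
# shapes noted in the function.
def _example_depth_smoothness_sketch():
    alpha_pred = torch.ones(2, 1, 64, 64)  # fully opaque everywhere
    depth_pred = torch.rand(2, 1, 64, 64)  # noisy depth -> non-zero penalty
    noisy = depth_smoothness_loss(alpha_pred, depth_pred)
    flat = depth_smoothness_loss(alpha_pred, torch.ones_like(depth_pred))
    return noisy, flat  # `flat` is exactly zero for a constant depth map
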
# add confidence support, unsup3d version (Laplace likelihood)
def photometric_loss_laplace(im1, im2, mask=None, conf_sigma=None):
    loss = (im1 - im2).abs()
    # loss = (im1 - im2).square()

    if conf_sigma is not None:
        loss = loss * 2**0.5 / (conf_sigma + EPS) + (conf_sigma + EPS).log()

    if mask is not None:
        mask = mask.expand_as(loss)
        loss = (loss * mask).sum() / mask.sum()
    else:
        loss = loss.mean()
    return loss


# Gaussian likelihood version, following "What Uncertainties Do We Need in
# Bayesian Deep Learning for Computer Vision?"; also used in the mask-aware
# VGG loss.
def photometric_loss(im1, im2, mask=None, conf_sigma=None):
    # loss = torch.nn.functional.mse_loss(im1, im2, reduce='none')
    loss = (im1 - im2).square()

    if conf_sigma is not None:
        loss = loss / (2 * conf_sigma**2 + EPS) + (conf_sigma + EPS).log()

    if mask is not None:
        mask = mask.expand_as(loss)
        loss = (loss * mask).sum() / mask.sum()
    else:
        loss = loss.mean()

    return loss


class E3DGELossClass(torch.nn.Module):

    def __init__(self, device, opt) -> None:
        super().__init__()

        self.opt = opt
        self.device = device
        self.criterionImg = {
            'mse': torch.nn.MSELoss(),
            'l1': torch.nn.L1Loss(),
            'charbonnier': CharbonnierLoss(),
        }[opt.color_criterion]

        self.criterion_latent = {
            'mse': torch.nn.MSELoss(),
            'l1': torch.nn.L1Loss(),
            'vae': feature_vae_loss
        }[opt.latent_criterion]

        # self.criterionLPIPS = LPIPS(net_type='alex', device=device).eval()
        if opt.lpips_lambda > 0:
            self.criterionLPIPS = LPIPSLoss(loss_weight=opt.lpips_lambda)
            # self.criterionLPIPS = torch.nn.MSELoss()

        if opt.id_lambda > 0:
            self.criterionID = IDLoss(device=device).eval()
        self.id_loss_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

        # define 3d rec loss, for occupancy
        # self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none')
        # self.criterion_alpha = torch.nn.SmoothL1Loss()

        # self.criterion3d_rec = torch.nn.MSELoss(reduction='none')
        self.criterion_alpha = torch.nn.L1Loss()

        if self.opt.xyz_lambda > 0:
            # self.criterion_xyz = torch.nn.SmoothL1Loss()
            # follow LION, but the xyz here is noisy...
            self.criterion_xyz = torch.nn.L1Loss()

        if self.opt.depth_lambda > 0:
            # ! this depth loss is not converging, no idea why
            self.criterion3d_rec = ScaleAndShiftInvariantLoss(alpha=0.5,
                                                              scales=1)
        else:
            self.criterion3d_rec = torch.nn.SmoothL1Loss(reduction='none')
        # self.silog_loss = SILogLoss()

        if self.opt.lambda_opa_reg > 0:
            # self.beta_mvp_dist = torch.distributions.beta.Beta(torch.tensor(0.5, device=device), torch.tensor(0.5, device=device))
            # self.beta_mvp_base_dist = torch.distributions.beta.Beta(torch.tensor(10, device=device), torch.tensor(0.5, device=device))  # force close to 1 for base
            # self.beta_mvp_base_dist = torch.distributions.beta.Beta(torch.tensor(0.6, device=device), torch.tensor(0.2, device=device))  # force close to 1 for base
            # force close to 1 for base
            self.beta_mvp_base_dist = torch.distributions.beta.Beta(
                torch.tensor(0.5, device=device),
                torch.tensor(0.25, device=device))

        logger.log('init loss class finished')

    def calc_scale_invariant_depth_loss(self, pred_depth: torch.Tensor,
                                        gt_depth: torch.Tensor,
                                        gt_depth_mask: torch.Tensor):
        """Apply 3D shape reconstruction supervision: supervise the rendered
        depth with the scale-and-shift-invariant loss.
        """
        shape_loss_dict = {}

        assert gt_depth_mask is not None
        shape_loss = self.criterion3d_rec(pred_depth, gt_depth, gt_depth_mask)

        # if shape_loss > 0.2:  # hinge loss, avoid ood gradient
        #     shape_loss = torch.zeros_like(shape_loss)
        # else:
        shape_loss = shape_loss.clamp(min=0.04)  # g-buffer depth is very noisy
        shape_loss *= self.opt.depth_lambda

        shape_loss_dict['loss_depth'] = shape_loss
        # shape_loss_dict['depth_fgratio'] = gt_depth_mask.mean()

        # return l_si, shape_loss_dict
        return shape_loss, shape_loss_dict
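
    # Note (added for clarity): `calc_depth_loss` below is the plain masked
    # depth alternative to `calc_scale_invariant_depth_loss` above. It expects
    # `self.criterion3d_rec` to be the elementwise SmoothL1Loss branch from
    # `__init__` (i.e. `opt.depth_lambda == 0`), whereas the scale-invariant
    # variant expects the ScaleAndShiftInvariantLoss branch.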
    def calc_depth_loss(self, pred_depth: torch.Tensor,
                        gt_depth: torch.Tensor,
                        gt_depth_mask: torch.Tensor):
        """Apply 3D shape reconstruction supervision.

        Basically supervise the depth with an L1 loss.
        """
        shape_loss_dict = {}

        shape_loss = self.criterion3d_rec(pred_depth, gt_depth)

        assert gt_depth_mask is not None
        shape_loss *= gt_depth_mask
        shape_loss = shape_loss.sum() / gt_depth_mask.sum()
        # else:
        #     shape_loss /= pred_depth.numel()

        # l_si = self.silog_loss(pred_depth, gt_depth, mask=None, interpolate=True, return_interpolated=False)
        # l_si *= self.opt.depth_lambda
        # shape_loss_dict['loss_depth'] = l_si

        shape_loss_dict['loss_depth'] = shape_loss.clamp(
            min=0, max=0.1) * self.opt.depth_lambda
        # shape_loss_dict['loss_depth'] = shape_loss.clamp(
        #     min=0, max=0.5) * self.opt.depth_lambda

        # return l_si, shape_loss_dict
        return shape_loss, shape_loss_dict

    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def calc_alpha_loss(self, pred_alpha, gt_depth_mask):
        # return self.criterionImg(alpha, gt_depth_mask.float())

        if gt_depth_mask.ndim == 3:
            gt_depth_mask = gt_depth_mask.unsqueeze(1)

        if gt_depth_mask.shape[1] == 3:
            gt_depth_mask = gt_depth_mask[:, 0:1, ...]  # B 1 H W

        assert pred_alpha.shape == gt_depth_mask.shape

        alpha_loss = self.criterion_alpha(pred_alpha, gt_depth_mask)
        # st()

        return alpha_loss

    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def calc_mask_mse_loss(
            self,
            input,
            gt,
            gt_depth_mask,
            # conf_sigma=None,
            conf_sigma_l1=None,
            # conf_sigma_percl=None,
            use_fg_ratio=False):
        if gt_depth_mask.ndim == 3:
            gt_depth_mask = gt_depth_mask.unsqueeze(1).repeat_interleave(3, 1)
        else:
            assert gt_depth_mask.shape == input.shape
        gt_depth_mask = gt_depth_mask.float()

        if conf_sigma_l1 is None:
            rec_loss = torch.nn.functional.mse_loss(
                input.float(), gt.float(),
                reduction='none')  # 'sum' already divides by batch size n
        else:
            # ! only calculate the Laplace term on the foreground; otherwise
            # the low background confidence produces large gradients.
            rec_loss = photometric_loss(input, gt, gt_depth_mask,
                                        conf_sigma_l1)
            return rec_loss

        # rec_loss = torch.nn.functional.l1_loss(  # for laplace loss
        #     input.float(), gt.float(),
        #     reduction='none')  # 'sum' already divide by batch size n

        # gt_depth_mask = torch.ones_like(gt_depth_mask)  # ! DEBUGGING
        # if conf_sigma is not None:  # from unsup3d, but an L2 version
        #     rec_loss = rec_loss * 2**0.5 / (conf_sigma + EPS) + (conf_sigma +
        #                                                          EPS).log()
        #     return rec_loss.mean()
        # rec_loss = torch.exp(-(rec_loss * 2**0.5 / (conf_sigma + EPS))) * 1 / (conf_sigma + EPS) / (2**0.5)

        fg_size = gt_depth_mask.sum()
        # fg_ratio = fg_size / torch.ones_like(gt_depth_mask).sum() if use_fg_ratio else 1
        fg_loss = rec_loss * gt_depth_mask
        fg_loss = fg_loss.sum() / fg_size  # * fg_ratio

        if self.opt.bg_lamdba > 0:
            bg_loss = rec_loss * (1 - gt_depth_mask)
            bg_loss = bg_loss.sum() / (1 - gt_depth_mask).sum()
            rec_loss = fg_loss + bg_loss * self.opt.bg_lamdba
        else:
            rec_loss = fg_loss

        return rec_loss

    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def calc_2d_rec_loss(
            self,
            input,
            gt,
            depth_fg_mask,
            test_mode=True,
            step=1,
            ignore_lpips=False,
            # conf_sigma=None,
            conf_sigma_l1=None,
            conf_sigma_percl=None,
            pred_alpha=None,
    ):
        opt = self.opt
        loss_dict = {}

        # logger.log(test_mode)
        # logger.log(input.min(), input.max(), gt.min(), gt.max())

        if test_mode or not opt.fg_mse:
            rec_loss = self.criterionImg(input, gt)
        else:
            rec_loss = self.calc_mask_mse_loss(
                input,
                gt,
                depth_fg_mask,
                conf_sigma_l1=conf_sigma_l1)
            # conf_sigma_percl=conf_sigma_percl)
            # conf_sigma)

        # if step == 300:
        #     st()

        if opt.lpips_lambda > 0 and step >= opt.lpips_delay_iter and not ignore_lpips:
            # tricky solution to avoid NaN in the LPIPS loss
            # with torch.autocast(device_type='cuda',
            #                     dtype=torch.float16,
            #                     enabled=False):

            # if test_mode or not opt.fg_mse:  # no need to calculate background lpips for ease of computation
            # inp_for_lpips = input * pred_alpha + torch.ones_like(input) * (1 - pred_alpha)
            # gt_for_lpips = gt * depth_fg_mask + torch.ones_like(gt) * (1 - depth_fg_mask)
            inp_for_lpips = input * pred_alpha
            gt_for_lpips = gt * depth_fg_mask

            width = input.shape[-1]
            if width == 192:  # triplane here
                # loss on a 128x128 center crop
                lpips_loss = self.criterionLPIPS(
                    inp_for_lpips[:, :, width // 2 - 64:width // 2 + 64,
                                  width // 2 - 64:width // 2 + 64],
                    gt_for_lpips[:, :, width // 2 - 64:width // 2 + 64,
                                 width // 2 - 64:width // 2 + 64],
                    conf_sigma_percl=conf_sigma_percl,
                )
            elif width > 256:
                # elif width > 192:
                # lpips_loss = self.criterionLPIPS(
                #     F.interpolate(inp_for_lpips, (256, 256), mode='bilinear'),
                #     F.interpolate(gt_for_lpips, (256, 256), mode='bilinear'),
                #     conf_sigma_percl=conf_sigma_percl,
                # )
                # patch = 80
                # patch = 128
                patch = 144
                middle_point = width // 2
                # loss on a (2 * patch)-wide, i.e. 288x288, center crop
                lpips_loss = self.criterionLPIPS(
                    inp_for_lpips[:, :,
                                  middle_point - patch:middle_point + patch,
                                  middle_point - patch:middle_point + patch],
                    gt_for_lpips[:, :,
                                 middle_point - patch:middle_point + patch,
                                 middle_point - patch:middle_point + patch],
                    conf_sigma_percl=conf_sigma_percl,
                )
            else:  # directly supervise when <= 256
                # ! add foreground mask
                assert pred_alpha is not None
                lpips_loss = self.criterionLPIPS(
                    inp_for_lpips,
                    gt_for_lpips,
                    # conf_sigma_percl=conf_sigma_percl,
                )

            # else:  # fg lpips
            #     assert depth_fg_mask.shape == input.shape
            #     lpips_loss = self.criterionLPIPS(
            #         input.contiguous() * depth_fg_mask,
            #         gt.contiguous() * depth_fg_mask).mean()

        else:
            lpips_loss = torch.tensor(0., device=input.device)

        if opt.ssim_lambda > 0:
            loss_ssim = self.ssim_loss(input, gt)  # ?
        else:
            loss_ssim = torch.tensor(0., device=input.device)

        loss_psnr = self.psnr((input / 2 + 0.5), (gt / 2 + 0.5), 1.0)

        if opt.id_lambda > 0:
            loss_id = self._calc_loss_id(input, gt)
        else:
            loss_id = torch.tensor(0., device=input.device)

        if opt.l1_lambda > 0:
            loss_l1 = F.l1_loss(input, gt)
        else:
            loss_l1 = torch.tensor(0., device=input.device)

        # loss = rec_loss * opt.l2_lambda + lpips_loss * opt.lpips_lambda + loss_id * opt.id_lambda + loss_ssim * opt.ssim_lambda
        rec_loss = rec_loss * opt.l2_lambda
        loss = rec_loss + lpips_loss + loss_id * opt.id_lambda + loss_ssim * opt.ssim_lambda + opt.l1_lambda * loss_l1

        # if return_dict:
        loss_dict['loss_l2'] = rec_loss
        loss_dict['loss_id'] = loss_id
        loss_dict['loss_lpips'] = lpips_loss
        loss_dict['loss'] = loss
        loss_dict['loss_ssim'] = loss_ssim

        # metrics to report, not involved in training
        loss_dict['mae'] = loss_l1
        loss_dict['PSNR'] = loss_psnr
        loss_dict['SSIM'] = 1 - loss_ssim  # Todo
        loss_dict['ID_SIM'] = 1 - loss_id

        return loss, loss_dict

    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def calc_shape_rec_loss(
        self,
        pred_shape: dict,
        gt_shape: dict,
        device,
    ):
        """Apply 3D shape reconstruction supervision: supervise the sampled
        densities with an L1-style loss.

        Args:
            pred_shape (dict): dict containing the reconstructed shape information
            gt_shape (dict): dict containing the gt shape information

        Returns:
            dict: shape reconstruction loss
        """
        shape_loss_dict = {}
        shape_loss = 0
        # assert supervise_sdf or supervise_surface_normal, 'should at least supervise one type of shape reconstruction'

        # todo, add weights
        if self.opt.shape_uniform_lambda > 0:
            shape_loss_dict['coarse'] = self.criterion3d_rec(
                pred_shape['coarse_densities'].squeeze(),
                gt_shape['coarse_densities'].squeeze())
            shape_loss += shape_loss_dict[
                'coarse'] * self.opt.shape_uniform_lambda

        if self.opt.shape_importance_lambda > 0:
            shape_loss_dict['fine'] = self.criterion3d_rec(
                pred_shape['fine_densities'].squeeze(),  # ? how to supervise
                gt_shape['fine_densities'].squeeze())
            shape_loss += shape_loss_dict[
                'fine'] * self.opt.shape_importance_lambda

        loss_depth = self.criterion_alpha(pred_shape['image_depth'],
                                          gt_shape['image_depth'])
        shape_loss += loss_depth * self.opt.shape_depth_lambda
        shape_loss_dict.update(dict(loss_depth=loss_depth))

        # TODO, add supervision on surface points?
        return shape_loss, shape_loss_dict

    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def psnr(self, input, target, max_val):
        return kornia.metrics.psnr(input, target, max_val)

    # @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def ssim_loss(self, img1, img2, window_size=11, size_average=True):
        channel = img1.size(-3)
        window = create_window(window_size, channel)

        if img1.is_cuda:
            window = window.cuda(img1.get_device())
        window = window.type_as(img1)

        return 1 - _ssim(img1, img2, window, window_size, channel,
                         size_average)

    @torch.autocast(device_type='cuda', dtype=torch.float16, enabled=False)
    def forward(self,
                pred,
                gt,
                test_mode=True,
                step=1,
                return_fg_mask=False,
                conf_sigma_l1=None,
                conf_sigma_percl=None,
                ignore_kl=False,
                ignore_lpips=False,
                *args,
                **kwargs):

        with torch.autocast(device_type='cuda',
                            dtype=torch.float16,
                            enabled=False):

            loss = torch.tensor(0., device=self.device)
            loss_dict = {}

            if 'image_mask' in pred:
                pred_alpha = pred['image_mask']  # B 1 H W
            else:
                N, _, H, W = pred['image_depth'].shape
                pred_alpha = pred['weights_samples'].permute(0, 2, 1).reshape(
                    N, 1, H, W)

            # balance rec_loss with logvar
            # if 'depth_mask' in gt:
            if self.opt.online_mask:
                # https://github.com/elliottwu/unsup3d/blob/dc961410d61684561f19525c2f7e9ee6f4dacb91/unsup3d/model.py#L193
                margin = (self.opt.max_depth - self.opt.min_depth) / 2
                fg_mask = (pred['image_depth'] <
                           self.opt.max_depth + margin).float()  # B 1 H W
                fg_mask = fg_mask.repeat_interleave(3, 1).float()
            else:
                if 'depth_mask' in gt:
                    if gt['depth_mask'].shape[1] != 1:
                        fg_mask = gt['depth_mask'].unsqueeze(1)
                    else:
                        fg_mask = gt['depth_mask']
                    fg_mask = fg_mask.repeat_interleave(3, 1).float()
                else:
                    fg_mask = None

            loss_2d, loss_2d_dict = self.calc_2d_rec_loss(
                pred['image_raw'],
                gt['img'],
                fg_mask,
                test_mode=test_mode,
                step=step,
                ignore_lpips=ignore_lpips,
                conf_sigma_l1=conf_sigma_l1,
                conf_sigma_percl=conf_sigma_percl,
                pred_alpha=pred_alpha,
            )  # ignore_lpips=self.opt.fg_mse)

            if self.opt.kl_lambda > 0 and not ignore_kl:
                # assert 'posterior' in pred, 'logvar' in pred
                assert 'posterior' in pred

                if self.opt.kl_anneal:
                    kl_lambda = kl_coeff(
                        step=step,
                        constant_step=5e3,  # keep the minimum weight for the first 5k steps
                        total_step=25e3,  # then anneal over 25k steps
                        min_kl_coeff=max(1e-9, self.opt.kl_lambda / 1e4),
                        max_kl_coeff=self.opt.kl_lambda)
                    loss_dict['kl_lambda'] = kl_lambda
                else:
                    loss_dict['kl_lambda'] = torch.tensor(
                        self.opt.kl_lambda, device=dist_util.dev())

                if self.opt.pt_ft_kl:
                    pt_kl, ft_kl = pred['posterior'].kl(pt_ft_separate=True)
                    kl_batch = pt_kl.shape[0]

                    # loss_dict['kl_loss_pt'] = pt_kl.sum() * loss_dict['kl_lambda'] * 0.01 / kl_batch
                    loss_dict['kl_loss_pt'] = pt_kl.sum() * loss_dict[
                        'kl_lambda'] * 0  # no compression at all.
                    loss_dict['kl_loss_ft'] = ft_kl.sum() * loss_dict[
                        'kl_lambda'] / kl_batch
                    loss = loss + loss_dict['kl_loss_pt'] + loss_dict[
                        'kl_loss_ft']

                    loss_dict['latent_mu_pt'] = pred[
                        'posterior'].mean[:, :3].mean()
                    loss_dict['latent_std_pt'] = pred[
                        'posterior'].std[:, :3].mean()
                    loss_dict['latent_mu_ft'] = pred[
                        'posterior'].mean[:, 3:].mean()
                    loss_dict['latent_std_ft'] = pred[
                        'posterior'].std[:, 3:].mean()

                elif self.opt.ft_kl:
                    ft_kl = pred['posterior'].kl(ft_separate=True)
                    kl_batch = ft_kl.shape[0]

                    loss_dict['kl_loss_ft'] = ft_kl.sum() * loss_dict[
                        'kl_lambda'] / kl_batch
                    loss = loss + loss_dict['kl_loss_ft']

                    loss_dict['latent_mu_ft'] = pred[
                        'posterior'].mean[:, :].square().mean().float().detach()
                    loss_dict['latent_std_ft'] = pred[
                        'posterior'].std[:, :].mean().float().detach()

                else:
                    kl_loss = pred['posterior'].kl()
                    kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
                    loss_dict['kl_loss'] = kl_loss * loss_dict['kl_lambda']
                    loss += loss_dict['kl_loss']

                    loss_dict['latent_mu'] = pred['posterior'].mean.mean()
                    loss_dict['latent_std'] = pred['posterior'].std.mean()

                # nll_loss = loss_2d / torch.exp(pred['logvar']) + pred['logvar']  # nll_loss
                nll_loss = loss_2d
                loss += nll_loss

                loss_dict.update(dict(nll_loss=nll_loss))

                # loss_dict['latent_mu'] = pred['latent_normalized'].mean()
                # loss_dict['latent_max'] = pred['latent_normalized'].max()
                # loss_dict['latent_min'] = pred['latent_normalized'].min()
                # loss_dict['latent_std'] = pred['latent_normalized'].std()
                # loss_dict['latent_mu'] = pred['latent_normalized_2Ddiffusion'].mean()
                # loss_dict['latent_std'] = pred['latent_normalized_2Ddiffusion'].std()
                # loss_dict['latent_max'] = pred['latent_normalized_2Ddiffusion'].max()
                # loss_dict['latent_min'] = pred['latent_normalized_2Ddiffusion'].min()

            else:
                loss += loss_2d

            # if 'image_sr' in pred and pred['image_sr'].shape == gt['img_sr']:
            if 'image_sr' in pred:

                if 'depth_mask_sr' in gt:
                    depth_mask_sr = gt['depth_mask_sr'].unsqueeze(
                        1).repeat_interleave(3, 1).float()
                else:
                    depth_mask_sr = None

                loss_sr, loss_sr_dict = self.calc_2d_rec_loss(
                    pred['image_sr'],
                    gt['img_sr'],
                    depth_fg_mask=depth_mask_sr,
                    # test_mode=test_mode,
                    test_mode=True,
                    step=step)
                loss_sr_lambda = 1
                if step < self.opt.sr_delay_iter:
                    loss_sr_lambda = 0
                loss += loss_sr * loss_sr_lambda
                for k, v in loss_sr_dict.items():
                    loss_dict['sr_' + k] = v * loss_sr_lambda

            if self.opt.depth_lambda > 0:  # TODO, switch to scale-agnostic depth loss
                assert 'depth' in gt
                pred_depth = pred['image_depth']
                if pred_depth.ndim == 4:
                    pred_depth = pred_depth.squeeze(1)  # B H W

                # _, shape_loss_dict = self.calc_depth_loss(
                #     pred_depth, gt['depth'], fg_mask[:, 0, ...])
                _, shape_loss_dict = self.calc_scale_invariant_depth_loss(
                    pred_depth, gt['depth'], fg_mask[:, 0, ...])
                loss += shape_loss_dict['loss_depth']
                loss_dict.update(shape_loss_dict)

            # if self.opt.latent_lambda > 0:  # make sure the latent suits diffusion learning
            #     latent_mu = pred['latent'].mean()
            #     loss_latent = self.criterion_latent(
            #         latent_mu, torch.zeros_like(latent_mu))  # only regularize the mean value here
            #     loss_dict['loss_latent'] = loss_latent
            #     loss += loss_latent * self.opt.latent_lambda

            if self.opt.alpha_lambda > 0 and 'image_depth' in pred:
                loss_alpha = self.calc_alpha_loss(pred_alpha, fg_mask)
                loss_dict['loss_alpha'] = loss_alpha * self.opt.alpha_lambda
                loss += loss_alpha * self.opt.alpha_lambda

            if self.opt.depth_smoothness_lambda > 0:
                loss_depth_smoothness = depth_smoothness_loss(
                    pred_alpha, pred['image_depth']
                ) * self.opt.depth_smoothness_lambda
                loss_dict['loss_depth_smoothness'] = loss_depth_smoothness
                loss += loss_depth_smoothness
            loss_2d_dict['all_loss'] = loss
            loss_dict.update(loss_2d_dict)

            # if return_fg_mask:
            return loss, loss_dict, fg_mask
            # else:
            #     return loss, loss_dict

    def _calc_loss_id(self, input, gt):
        if input.shape[-1] != 256:
            arcface_input = self.id_loss_pool(input)
            id_loss_gt = self.id_loss_pool(gt)
        else:
            arcface_input = input
            id_loss_gt = gt

        loss_id, _, _ = self.criterionID(arcface_input, id_loss_gt,
                                         id_loss_gt)
        return loss_id

    def calc_2d_rec_loss_misaligned(self, input, gt):
        """ID loss + VGG (LPIPS) loss for misaligned input/gt pairs.

        Args:
            input: reconstructed image.
            gt: ground-truth image.
        """
        opt = self.opt
        loss_dict = {}

        if opt.lpips_lambda > 0:
            with torch.autocast(
                    device_type='cuda', dtype=torch.float16,
                    enabled=False):  # close AMP for lpips to avoid nan
                lpips_loss = self.criterionLPIPS(input, gt)
        else:
            lpips_loss = torch.tensor(0., device=input.device)

        if opt.id_lambda > 0:
            loss_id = self._calc_loss_id(input, gt)
        else:
            loss_id = torch.tensor(0., device=input.device)

        loss_dict['loss_id_real'] = loss_id
        loss_dict['loss_lpips_real'] = lpips_loss

        loss = lpips_loss * opt.lpips_lambda + loss_id * opt.id_lambda

        return loss, loss_dict


class E3DGE_with_AdvLoss(E3DGELossClass):
    # adapted from sgm/modules/autoencoding/losses/discriminator_loss.py

    def __init__(
        self,
        device,
        opt,
        discriminator_config: Optional[Dict] = None,
        disc_num_layers: int = 3,
        disc_in_channels: int = 3,
        disc_start: int = 0,
        disc_loss: str = "hinge",
        disc_factor: float = 1.0,
        disc_weight: float = 1.0,
        regularization_weights: Union[None, Dict[str, float]] = None,
        dtype=torch.float32,
        # additional_log_keys: Optional[List[str]] = None,
    ) -> None:
        super().__init__(device, opt)

        # ! initialize GAN loss
        discriminator_config = default(
            discriminator_config,
            {
                "target": "nsr.losses.disc.NLayerDiscriminator",
                "params": {
                    "input_nc": disc_in_channels,
                    "n_layers": disc_num_layers,
                    "use_actnorm": False,
                },
            },
        )

        self.discriminator = instantiate_from_config(
            discriminator_config).apply(weights_init)
        self.discriminator_iter_start = disc_start
        self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight

        # self.regularization_weights = default(regularization_weights, {})
        # self.forward_keys = [
        #     "optimizer_idx",
        #     "global_step",
        #     "last_layer",
        #     "split",
        #     "regularization_log",
        # ]

        # self.additional_log_keys = set(default(additional_log_keys, []))
        # self.additional_log_keys.update(set(self.regularization_weights.keys()))

    def get_trainable_parameters(self) -> Iterator[nn.Parameter]:
        return self.discriminator.parameters()

    def forward(self,
                pred,
                gt,
                behaviour: str,
                test_mode=True,
                step=1,
                return_fg_mask=False,
                conf_sigma_l1=None,
                conf_sigma_percl=None,
                ignore_d_loss=False,
                *args,
                **kwargs):

        # now the GAN part
        reconstructions = pred['image_raw']
        inputs = gt['img']

        if behaviour == 'g_step':

            nll_loss, loss_dict, fg_mask = super().forward(
                pred, gt, test_mode, step, return_fg_mask, conf_sigma_l1,
                conf_sigma_percl, *args, **kwargs)

            # generator update
            if not ignore_d_loss and (step >= self.discriminator_iter_start
                                      or not self.training):
                logits_fake = self.discriminator(reconstructions.contiguous())
                g_loss = -torch.mean(logits_fake)
                if self.training:
                    d_weight = torch.tensor(self.discriminator_weight)
                    # d_weight = self.calculate_adaptive_weight(
                    #     nll_loss, g_loss, last_layer=last_layer)
                else:
                    d_weight = torch.tensor(1.0)
            else:
                d_weight = torch.tensor(0.0)
                g_loss = torch.tensor(0.0, requires_grad=True)
            g_loss = g_loss * d_weight * self.disc_factor
            loss = nll_loss + g_loss  # TODO

            loss_dict.update({
                "loss/g": g_loss.detach().mean(),
            })

            # return loss, log
            return loss, loss_dict, fg_mask

        elif behaviour == 'd_step' and not ignore_d_loss:
            # second pass for the discriminator update
            logits_real = self.discriminator(inputs.contiguous().detach())
            logits_fake = self.discriminator(
                reconstructions.contiguous().detach())

            if step >= self.discriminator_iter_start or not self.training:
                d_loss = self.disc_factor * self.disc_loss(
                    logits_real, logits_fake)
            else:
                d_loss = torch.tensor(0.0, requires_grad=True)

            loss_dict = {}

            loss_dict.update({
                "loss/disc": d_loss.clone().detach().mean(),
                "logits/real": logits_real.detach().mean(),
                "logits/fake": logits_fake.detach().mean(),
            })

            return d_loss, loss_dict, None
        else:
            raise NotImplementedError(
                f"Unknown optimizer behaviour {behaviour}")
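

# --- Illustrative sketch (added for clarity, not part of the original training code) ---
# `E3DGE_with_AdvLoss.forward` alternates between a generator pass
# (`behaviour='g_step'`, which adds -E[D(fake)] on top of the reconstruction
# losses) and a discriminator pass (`behaviour='d_step'`, which applies the
# hinge or vanilla GAN loss to real/fake logits). The toy logits below only
# demonstrate those two terms; they are random placeholders, not model output.
def _example_adversarial_terms_sketch():
    logits_real = torch.randn(4, 1, 30, 30)  # D(real images)
    logits_fake = torch.randn(4, 1, 30, 30)  # D(reconstructions)
    g_loss = -torch.mean(logits_fake)  # generator term (g_step)
    d_loss = hinge_d_loss(logits_real, logits_fake)  # discriminator term (d_step)
    return g_loss, d_loss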