define-hf-demo / vidar /arch /losses /SmoothnessLoss.py
Jiading Fang
add define
fc16538
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
from abc import ABC
import torch
from vidar.arch.losses.BaseLoss import BaseLoss
from vidar.utils.tensor import same_shape, interpolate_image
class SmoothnessLoss(BaseLoss, ABC):
    """
    Edge-aware smoothness loss on predicted depth maps.

    Penalizes absolute horizontal and vertical depth differences, scaled by
    exp(-|image gradient|) so that large depth changes cost less where the
    input image itself changes strongly.

    Parameters
    ----------
    cfg : Config
        Configuration with parameters; ``cfg.normalize`` toggles dividing each
        depth map by its spatial mean before differencing
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        self.normalize = cfg.normalize

    def calculate(self, rgb, depth):
        """
        Compute the smoothness loss for a single scale

        Parameters
        ----------
        rgb : torch.Tensor
            Input image [B,3,H,W]
        depth : torch.Tensor
            Predicted depth map [B,1,H,W]

        Returns
        -------
        loss : torch.Tensor
            Scalar smoothness loss (mean x-term plus mean y-term)
        """
        # Optionally divide by the mean over dims 2 and 3 (taken in that order)
        # so the penalty is insensitive to overall depth scale; 1e-7 avoids /0
        if self.normalize:
            spatial_mean = depth.mean(dim=2, keepdim=True).mean(dim=3, keepdim=True)
            target = depth / (spatial_mean + 1e-7)
        else:
            target = depth

        # Absolute finite differences of depth along the last dim (x) and dim 2 (y)
        depth_dx = (target[:, :, :, :-1] - target[:, :, :, 1:]).abs()
        depth_dy = (target[:, :, :-1, :] - target[:, :, 1:, :]).abs()

        # Image differences averaged over the channel dim, kept as one channel
        image_dx = (rgb[:, :, :, :-1] - rgb[:, :, :, 1:]).abs().mean(dim=1, keepdim=True)
        image_dy = (rgb[:, :, :-1, :] - rgb[:, :, 1:, :]).abs().mean(dim=1, keepdim=True)

        # Edge-aware weighting: strong image gradients shrink the depth penalty
        weighted_dx = depth_dx * torch.exp(-image_dx)
        weighted_dy = depth_dy * torch.exp(-image_dy)

        return weighted_dx.mean() + weighted_dy.mean()

    def forward(self, rgb, depth):
        """
        Compute the multi-scale smoothness loss

        Parameters
        ----------
        rgb : list[torch.Tensor]
            Input images, one per scale [B,3,H,W]
        depth : list[torch.Tensor]
            Predicted depth maps, one per scale [B,1,H,W]

        Returns
        -------
        output : Dict
            'loss': average of the weighted per-scale losses;
            'metrics': detached weighted per-scale losses keyed
            'smoothness_loss/{scale}'
        """
        num_scales = self.get_scales(rgb)
        scale_weights = self.get_weights(num_scales)

        metrics = {}
        per_scale = []
        for scale in range(num_scales):
            image, depth_map = rgb[scale], depth[scale]
            # Bring the image to the depth map's spatial size when they differ
            if not same_shape(image.shape[-2:], depth_map.shape[-2:]):
                image = interpolate_image(image, shape=depth_map.shape[-2:])
            weighted = scale_weights[scale] * self.calculate(image, depth_map)
            metrics[f'smoothness_loss/{scale}'] = weighted.detach()
            per_scale.append(weighted)

        # NOTE(review): if get_scales returns 0 this divides by zero — confirm
        # BaseLoss always yields at least one scale
        return {
            'loss': sum(per_scale) / len(per_scale),
            'metrics': metrics,
        }