# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

from abc import ABC

import torch

from vidar.arch.losses.BaseLoss import BaseLoss
from vidar.utils.tensor import same_shape, interpolate_image


class SmoothnessLoss(BaseLoss, ABC):
    """
    Edge-aware smoothness loss: penalizes gradients in the predicted depth map,
    attenuated where the input image itself has strong gradients, so depth
    discontinuities at image edges are not penalized

    Parameters
    ----------
    cfg : Config
        Configuration with parameters (expects a boolean `normalize` flag)
    """
    def __init__(self, cfg):
        super().__init__(cfg)
        self.normalize = cfg.normalize

    def calculate(self, rgb, depth):
        """
        Calculate edge-aware smoothness loss for a single RGB/depth pair
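
            L = mean(|dx(d)| * exp(-|dx(I)|)) + mean(|dy(d)| * exp(-|dy(I)|))

        where d is the (optionally mean-normalized) depth and I the input image.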

        Parameters
        ----------
        rgb : torch.Tensor
            Input image [B,3,H,W]
        depth : torch.Tensor
            Predicted depth map [B,1,H,W]

        Returns
        -------
        loss : torch.Tensor
            Smoothness loss [1]
        """
        if self.normalize:
            mean_depth = depth.mean(2, True).mean(3, True)
            norm_depth = depth / (mean_depth + 1e-7)
        else:
            norm_depth = depth

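        # First-order finite-difference depth gradients along x and y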
        grad_depth_x = torch.abs(norm_depth[:, :, :, :-1] - norm_depth[:, :, :, 1:])
        grad_depth_y = torch.abs(norm_depth[:, :, :-1, :] - norm_depth[:, :, 1:, :])

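        # Image gradients, averaged over color channels into a single edge map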
        grad_rgb_x = torch.mean(torch.abs(rgb[:, :, :, :-1] - rgb[:, :, :, 1:]), 1, keepdim=True)
        grad_rgb_y = torch.mean(torch.abs(rgb[:, :, :-1, :] - rgb[:, :, 1:, :]), 1, keepdim=True)

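        # Edge-aware weighting: damp the smoothness penalty where the image
        # has strong gradients (likely object boundaries)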
        grad_depth_x *= torch.exp(-1.0 * grad_rgb_x)
        grad_depth_y *= torch.exp(-1.0 * grad_rgb_y)

        return grad_depth_x.mean() + grad_depth_y.mean()

    def forward(self, rgb, depth):
        """
        Calculate multi-scale smoothness loss, averaged over scales

        Parameters
        ----------
        rgb : list[torch.Tensor]
            Input images, one per scale [B,3,H,W]
        depth : list[torch.Tensor]
            Predicted depth maps, one per scale [B,1,H,W]

        Returns
        -------
        output : Dict
            Dictionary with loss and metrics
        """
        scales = self.get_scales(rgb)
        weights = self.get_weights(scales)

        losses, metrics = [], {}

        for i in range(scales):
            rgb_i, depth_i = rgb[i], depth[i]
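            # Resize the image to match this scale's depth resolution if needed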
            if not same_shape(rgb_i.shape[-2:], depth_i.shape[-2:]):
                rgb_i = interpolate_image(rgb_i, shape=depth_i.shape[-2:])

            loss_i = weights[i] * self.calculate(rgb_i, depth_i)

            metrics[f'smoothness_loss/{i}'] = loss_i.detach()
            losses.append(loss_i)

        loss = sum(losses) / len(losses)

        return {
            'loss': loss,
            'metrics': metrics,
        }
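
# Usage sketch (illustrative, not part of the original module): shows how the
# loss might be called, assuming BaseLoss subclasses torch.nn.Module and
# provides the `get_scales`/`get_weights` helpers used in forward(). The `cfg`
# object is a stand-in for vidar's Config with a boolean `normalize` field.
#
#   loss_fn = SmoothnessLoss(cfg)
#   rgbs = [torch.rand(2, 3, 64, 64), torch.rand(2, 3, 32, 32)]
#   depths = [torch.rand(2, 1, 64, 64), torch.rand(2, 1, 32, 32)]
#   out = loss_fn(rgbs, depths)   # {'loss': scalar tensor, 'metrics': {...}}
#   out['loss'].backward()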