# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.

from abc import ABC

from vidar.arch.blocks.image.ViewSynthesis import ViewSynthesis
from vidar.arch.models.BaseModel import BaseModel
from vidar.arch.models.utils import make_rgb_scales, create_cameras
from vidar.utils.data import get_from_dict
from vidar.utils.config import cfg_has


class SelfSupervisedModel(BaseModel, ABC):
    """
    Self-supervised depth estimation model

    Parameters
    ----------
    cfg : Config
        Configuration with parameters
    """

    # Names of the flattened intrinsics coefficients for each supported
    # camera model, in the order the intrinsics network emits them.
    # Used only for logging each learned coefficient as a training metric.
    INTRINSICS_KEYS = {
        'UCM': ('fx', 'fy', 'cx', 'cy', 'alpha'),
        'EUCM': ('fx', 'fy', 'cx', 'cy', 'alpha', 'beta'),
        'DS': ('fx', 'fy', 'cx', 'cy', 'xi', 'alpha'),
    }

    def __init__(self, cfg):
        super().__init__(cfg)
        self.view_synthesis = ViewSynthesis()
        self.set_attr(cfg.model, 'use_gt_pose', False)
        self.set_attr(cfg.model, 'use_gt_intrinsics', True)
        if not self.use_gt_intrinsics:
            # Intrinsics are learned: select the unified camera model class.
            # Geometry modules are imported lazily so only the configured
            # camera implementation is loaded.
            self.camera_model = cfg_has(cfg.networks.intrinsics, 'camera_model', 'UCM')
            if self.camera_model == 'UCM':
                from vidar.geometry.camera_ucm import UCMCamera
                self.camera_class = UCMCamera
            elif self.camera_model == 'EUCM':
                from vidar.geometry.camera_eucm import EUCMCamera
                self.camera_class = EUCMCamera
            elif self.camera_model == 'DS':
                from vidar.geometry.camera_ds import DSCamera
                self.camera_class = DSCamera
            else:
                raise NotImplementedError('Invalid camera type')

    def forward(self, batch, epoch=0):
        """Model forward pass

        Parameters
        ----------
        batch : Dict
            Input batch; requires 'rgb', and optionally 'intrinsics',
            'mask', 'pose' and 'depth' (keyed by temporal context index)
        epoch : Int
            Current training epoch (unused here; kept for interface parity)

        Returns
        -------
        Dict
            At inference: {'predictions'}; at training time additionally
            {'loss', 'metrics'}
        """
        rgb = batch['rgb']

        # Intrinsics either come from the batch (GT) or are predicted
        # from the target image by a dedicated network
        if self.use_gt_intrinsics:
            intrinsics = get_from_dict(batch, 'intrinsics')
        else:
            intrinsics = self.networks['intrinsics'](rgb=rgb[0])

        valid_mask = get_from_dict(batch, 'mask')

        # The depth network only receives intrinsics when they are GT
        if self.use_gt_intrinsics:
            depth_output = self.networks['depth'](rgb=rgb[0], intrinsics=intrinsics[0])
        else:
            depth_output = self.networks['depth'](rgb=rgb[0])
        pred_depth = depth_output['depths']

        predictions = {
            'depth': {0: pred_depth},
        }

        # Optional log-variance output for aleatoric uncertainty training
        pred_logvar = get_from_dict(depth_output, 'logvar')
        if pred_logvar is not None:
            predictions['logvar'] = {0: pred_logvar}

        # Inference stops here; losses are only computed while training
        if not self.training:
            return {
                'predictions': predictions,
            }

        # Pose: GT from the batch, predicted by a pose network, or absent
        if self.use_gt_pose:
            assert 'pose' in batch, 'You need input pose'
            pose = batch['pose']
        elif 'pose' in self.networks:
            pose = self.compute_pose(rgb, self.networks['pose'], tgt=0, invert=True)
            predictions['pose'] = pose
        else:
            pose = None

        if not self.use_gt_intrinsics:
            # Build cameras from the predicted intrinsics vector; the target
            # camera (key 0) has identity pose, context cameras carry Tcw.
            # NOTE(review): this path requires pose to be available
            cams = {0: self.camera_class(I=intrinsics)}
            for key in pose:
                cams[key] = self.camera_class(I=intrinsics, Tcw=pose[key])
        else:
            cams = create_cameras(rgb[0], intrinsics[0], pose)

        gt_depth = None if 'depth' not in batch else batch['depth'][0]
        loss, metrics = self.compute_loss_and_metrics(
            rgb, pred_depth, cams,
            gt_depth=gt_depth, logvar=pred_logvar, valid_mask=valid_mask,
        )

        if not self.use_gt_intrinsics:
            # Log each learned intrinsics coefficient as a scalar metric
            if self.camera_model not in self.INTRINSICS_KEYS:
                raise NotImplementedError('Invalid camera type')
            keys = self.INTRINSICS_KEYS[self.camera_model]
            metrics.update(dict(zip(keys, intrinsics[0].squeeze())))

        return {
            'loss': loss,
            'metrics': metrics,
            'predictions': predictions,
        }

    def compute_loss_and_metrics(self, rgb, depths, cams,
                                 gt_depth=None, logvar=None, valid_mask=None):
        """
        Compute loss and metrics for training

        Parameters
        ----------
        rgb : Dict
            Dictionary with input images [B,3,H,W]
        depths : list[torch.Tensor]
            List with target depth maps in different scales [B,1,H,W]
        cams : Dict
            Dictionary with cameras for each input image
        gt_depth : torch.Tensor
            Ground-truth depth map for supervised training
        logvar : list[torch.Tensor]
            Log-variance maps for uncertainty training
        valid_mask : list[torch.Tensor]
            Binary mask for masking out invalid pixels [B,1,H,W]

        Returns
        -------
        loss : torch.Tensor
            Training loss
        metrics : Dict
            Dictionary with training metrics
        """
        tgt = 0
        ctx = [key for key in rgb.keys() if key != tgt]
        num_scales = self.get_num_scales(depths)

        # Multi-scale image pyramids matching the depth pyramid resolution
        rgbs = make_rgb_scales(rgb, pyramid=depths)
        rgb_tgt = [rgbs[tgt][i] for i in range(num_scales)]
        rgb_ctx = [[rgbs[j][i] for j in ctx] for i in range(num_scales)]

        loss, metrics = [], {}

        # Photometric loss on context images warped into the target view
        if 'reprojection' in self.losses:
            synth = self.view_synthesis(
                rgbs, depths=depths, cams=cams, return_masks=True)
            reprojection_output = self.losses['reprojection'](
                rgb_tgt, rgb_ctx, synth['warps'], logvar=logvar,
                valid_mask=valid_mask, overlap_mask=synth['masks'])
            loss.append(reprojection_output['loss'])
            metrics.update(**reprojection_output['metrics'])

        # Edge-aware depth smoothness regularization
        if 'smoothness' in self.losses:
            smoothness_output = self.losses['smoothness'](rgb_tgt, depths)
            loss.append(smoothness_output['loss'])
            metrics.update(**smoothness_output['metrics'])

        # Optional direct depth supervision (only when GT depth is given)
        if 'supervision' in self.losses and gt_depth is not None:
            supervision_output = self.losses['supervision'](depths, gt_depth)
            loss.append(supervision_output['loss'])
            metrics.update(**supervision_output['metrics'])

        # Optional surface-normal consistency against GT depth
        if 'normals' in self.losses and gt_depth is not None:
            normals_output = self.losses['normals'](depths, gt_depth, cams[0])
            loss.append(normals_output['loss'])
            metrics.update(**normals_output['metrics'])

        loss = sum(loss)

        return loss, metrics