from collections import OrderedDict

import torch
import torch.nn as nn

from transformers import PerceiverModel, PerceiverConfig

from vidar.arch.networks.perceiver.externals.modeling_perceiver import (
    PerceiverDepthDecoder, PerceiverRGBDecoder, build_position_encoding)
from vidar.arch.blocks.depth.SigmoidToInvDepth import SigmoidToInvDepth
from vidar.arch.networks.decoders.DepthDecoder import DepthDecoder
from vidar.arch.networks.encoders.ResNetEncoder import ResNetEncoder
from vidar.utils.config import Config
from vidar.utils.networks import freeze_layers_and_norms
from vidar.utils.tensor import interpolate
from vidar.utils.types import is_int


class DownSampleRGB(nn.Module):
    """Convolutional stem that downsamples an RGB image 4x."""
    def __init__(self, out_dim):
        super().__init__()
        self.conv = nn.Conv2d(3, out_dim, kernel_size=7, stride=2, padding=3)
        self.norm = nn.BatchNorm2d(out_dim)
        self.actv = nn.ReLU()
        self.pool = nn.MaxPool2d(2, stride=2)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.actv(x)
        x = self.pool(x)
        return x


class DefineNet(nn.Module):
    """Perceiver-based network that encodes posed views into a latent scene
    representation and decodes depth and/or RGB at queried camera rays."""
    def __init__(self, cfg):
        super().__init__()
        self.tasks = cfg.tasks
        self.to_world = cfg.to_world
        self.depth_range = cfg.depth_range
        self.rgb_feat_dim = cfg.rgb_feat_dim
        self.rgb_feat_type = cfg.rgb_feat_type
        self.encoder_with_rgb = cfg.encoder_with_rgb
        self.decoder_with_rgb = cfg.decoder_with_rgb
        self.output_mode = cfg.output_mode
        self.sample_encoding_rays = cfg.sample_encoding_rays
        self.with_monodepth = cfg.with_monodepth
        self.upsample_convex = cfg.upsample_convex
        self.downsample_encoder = cfg.downsample_encoder
        self.downsample_decoder = cfg.downsample_decoder
        self.image_shape = [s // self.downsample_encoder for s in cfg.image_shape]

        # Fourier positional encodings for ray origins and viewing directions
        self.fourier_encoding_orig, _ = build_position_encoding(
            position_encoding_type='fourier',
            fourier_position_encoding_kwargs={
                'num_bands': cfg.num_bands_orig,
                'max_resolution': [cfg.max_resolution_orig] * 3,
                'concat_pos': True,
                'sine_only': False,
            })
        self.fourier_encoding_dirs, _ = build_position_encoding(
            position_encoding_type='fourier',
            fourier_position_encoding_kwargs={
                'num_bands': cfg.num_bands_dirs,
                # NOTE: num_bands_dirs doubles as the max resolution here
                'max_resolution': [cfg.num_bands_dirs] * 3,
                'concat_pos': True,
                'sine_only': False,
            })

        # Token dimensions for encoder inputs and decoder queries
        tot_encoder = self.fourier_encoding_orig.output_size() + \
            self.fourier_encoding_dirs.output_size()
        if self.encoder_with_rgb:
            tot_encoder += self.rgb_feat_dim

        tot_decoder = self.fourier_encoding_orig.output_size() + \
            self.fourier_encoding_dirs.output_size()
        if self.decoder_with_rgb:
            tot_decoder += self.rgb_feat_dim

        tot_decoder_depth = tot_decoder
        tot_decoder_rgb = tot_decoder

        self.config = PerceiverConfig(
            train_size=self.image_shape,
            d_latents=cfg.d_latents,
            d_model=tot_encoder,
            num_latents=cfg.num_latents,
            hidden_act='gelu',
            hidden_dropout_prob=cfg.hidden_dropout_prob,
            initializer_range=0.02,
            layer_norm_eps=1e-12,
            num_blocks=1,
            num_cross_attention_heads=cfg.num_cross_attention_heads,
            num_self_attends_per_block=cfg.num_self_attends_per_block,
            num_self_attention_heads=cfg.num_self_attention_heads,
            qk_channels=None,
            v_channels=None,
        )

        if 'depth' in self.tasks:
            self.decoder = PerceiverDepthDecoder(
                self.config,
                num_channels=tot_decoder_depth,
                use_query_residual=False,
                output_num_channels=1,
                position_encoding_type='none',
                min_depth=self.depth_range[0],
                max_depth=self.depth_range[1],
                num_heads=cfg.decoder_num_heads,
                upsample_mode=cfg.upsample_convex,
                upsample_value=cfg.downsample_decoder,
                output_mode=cfg.output_mode,
            )

        if 'rgb' in self.tasks:
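            # RGB decoder head: same cross-attention decoder as depth, but
            # with 3 output channels and no depth bounds.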
            self.decoder_rgb = PerceiverRGBDecoder(
                self.config,
                num_channels=tot_decoder_rgb,
                use_query_residual=False,
                output_num_channels=3,
                position_encoding_type='none',
                num_heads=cfg.decoder_num_heads,
                upsample_mode=cfg.upsample_convex,
                upsample_value=cfg.downsample_decoder,
            )

        self.model = PerceiverModel(self.config)

        # Image feature extractor used to build RGB input/query tokens
        if self.rgb_feat_type == 'convnet':
            self.feature = DownSampleRGB(out_dim=self.rgb_feat_dim)
        elif self.rgb_feat_type in ['resnet', 'resnet_all', 'resnet_all_rgb']:
            self.feature = ResNetEncoder(Config(version=18, pretrained=True, num_rgb_in=1))

        # Optional monocular depth head, trained alongside the Perceiver
        if self.with_monodepth:
            self.mono_encoder = ResNetEncoder(Config(version=18, pretrained=True, num_rgb_in=1))
            self.mono_decoder = DepthDecoder(Config(
                num_scales=4,
                use_skips=True,
                num_ch_enc=self.feature.num_ch_enc,
                num_ch_out=1,
                activation='sigmoid',
            ))
            self.sigmoid_to_depth = SigmoidToInvDepth(
                min_depth=self.depth_range[0], max_depth=self.depth_range[1], return_depth=True)

    def get_rgb_feat(self, rgb):
        """Extract image features according to `rgb_feat_type`."""
        if self.rgb_feat_type == 'convnet':
            return {'feat': self.feature(rgb)}
        elif self.rgb_feat_type == 'resnet':
            return {'feat': self.feature(rgb)[1]}
        elif self.rgb_feat_type.startswith('resnet_all'):
            # Resize all pyramid levels to the highest resolution and stack them
            all_feats = self.feature(rgb)
            feats = all_feats[1:]
            for i in range(1, len(feats)):
                feats[i] = interpolate(
                    feats[i], size=feats[0], scale_factor=None,
                    mode='bilinear', align_corners=True)
            if self.rgb_feat_type.endswith('rgb'):
                # Also append the (resized) input image as an extra feature
                feats = feats + [interpolate(
                    rgb, size=feats[0], scale_factor=None,
                    mode='bilinear', align_corners=True)]
            feat = torch.cat(feats, 1)
            return {'all_feats': all_feats, 'feat': feat}

    def run_monodepth(self, rgb, freeze):
        """Run the monocular depth head, optionally freezing its weights."""
        freeze_layers_and_norms(self.mono_encoder, flag_freeze=freeze)
        freeze_layers_and_norms(self.mono_decoder, flag_freeze=freeze)
        mono_features = self.mono_encoder(rgb)
        mono_output = self.mono_decoder(mono_features)
        sigmoids = [mono_output[('output', 0)]]  # highest-resolution scale only
        return self.sigmoid_to_depth(sigmoids)[0]

    def embeddings(self, data, sources, downsample):
        """Build per-view embeddings from camera rays and RGB features."""
        if 'rgb' in sources:
            assert 'rgb' in data[0].keys()
            b = [datum['rgb'].shape[0] for datum in data]
            rgb = torch.cat([datum['rgb'] for datum in data], 0)
            output_feats = self.get_rgb_feat(rgb)
            feats = torch.split(output_feats['feat'], b)
            for i in range(len(data)):
                data[i]['feat'] = feats[i]
            if self.with_monodepth:
                depth = self.run_monodepth(rgb, freeze=False)
                depth = torch.split(depth, b)
                for i in range(len(data)):
                    data[i]['depth_mono'] = depth[i]
        encodings = []
        for datum in data:
            encoding = OrderedDict()
            if 'cam' in sources:
                assert 'cam' in data[0].keys()
                cam = datum['cam'].scaled(1. / downsample)
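                # Fourier-encode the camera rays: origins and viewing
                # directions each get their own positional encoding; when
                # `to_world` is False, directions come from a camera with
                # its translation removed.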
                orig = cam.get_origin(flatten=True)
                if self.to_world:
                    dirs = cam.get_viewdirs(normalize=True, flatten=True, to_world=True)
                else:
                    dirs = cam.no_translation().get_viewdirs(
                        normalize=True, flatten=True, to_world=True)
                orig_encodings = self.fourier_encoding_orig(
                    index_dims=None, pos=orig, batch_size=orig.shape[0], device=orig.device)
                dirs_encodings = self.fourier_encoding_dirs(
                    index_dims=None, pos=dirs, batch_size=dirs.shape[0], device=dirs.device)
                encoding['cam'] = torch.cat([orig_encodings, dirs_encodings], -1)
            if 'rgb' in sources:
                # Flatten spatial dimensions so features become per-ray tokens
                rgb = datum['feat']
                rgb_flat = rgb.view(*rgb.shape[:-2], -1).permute(0, 2, 1)
                encoding['rgb'] = rgb_flat
            encoding['all'] = torch.cat(list(encoding.values()), -1)
            encodings.append(encoding)
        return encodings

    @staticmethod
    def sample_decoder(data, embeddings, field, sample_queries, filter_invalid):
        """Randomly subsample decoder queries, optionally restricted to valid
        (non-negative) pixels, keeping the same number of queries per sample."""
        query_idx = []
        if filter_invalid:
            # Use the smallest number of valid pixels across all samples
            # as the common query budget
            tot_min = []
            for i in range(len(embeddings)):
                for b in range(data[i]['rgb'].shape[0]):
                    tot_min.append((data[i]['rgb'][b].mean(0) >= 0).sum())
            tot_min = min(tot_min)
            tot = embeddings[0][field][0].shape[0]
            tot = int(sample_queries * tot)
            tot = min([tot, tot_min])
        for i in range(len(embeddings)):
            idx = []
            for b in range(data[i]['rgb'].shape[0]):
                if filter_invalid:
                    # Keep `tot` random indices among the valid pixels
                    valid = data[i]['rgb'][b].mean(0, keepdim=True) >= 0
                    valid = valid.view(1, -1).permute(1, 0)
                    num = embeddings[i][field][0].shape[0]
                    all_idx = torch.arange(num, device=valid.device).unsqueeze(1)
                    valid_idx = all_idx[valid]
                    num = valid_idx.shape[0]
                    idx_i = torch.randperm(num)[tot:]
                    valid[valid_idx[idx_i]] = 0
                    idx_i = all_idx[valid]
                else:
                    num = embeddings[i][field][0].shape[0]
                    tot = int(sample_queries * num)
                    idx_i = torch.randperm(num)[:tot]
                idx.append(idx_i)
            idx = torch.stack(idx, 0)
            embeddings[i][field] = torch.stack(
                [embeddings[i][field][b][idx[b]] for b in range(idx.shape[0])], 0)
            query_idx.append(idx)
        return query_idx, embeddings

    def forward(self, encode_data, decode_data=None, sample_queries=0, filter_invalid=False):
        encode_field = 'all' if self.encoder_with_rgb else 'cam'
        decode_field = 'all' if self.decoder_with_rgb else 'cam'
        encode_sources = ['rgb', 'cam']
        decode_sources = ['cam']
        shape = encode_data[0]['cam'].hw
        output = {}
        encode_dict = self.encode(
            data=encode_data, field=encode_field, sources=encode_sources)
        if 'depth_mono' in encode_data[0].keys():
            output['depth_mono'] = [datum['depth_mono'] for datum in encode_data]
        # If no separate decode views are given, decode at the encoded views
        decode_embeddings = encode_dict['embeddings'] if decode_data is None else None
        decode_dict = self.decode(
            latent=encode_dict['latent'], shape=shape, data=decode_data,
            embeddings=decode_embeddings, field=decode_field, sources=decode_sources,
            sample_queries=sample_queries, filter_invalid=filter_invalid)
        output.update(decode_dict['output'])
        return {
            'output': output,
            'encode_embeddings': encode_dict['embeddings'],
            'decode_embeddings': decode_dict['embeddings'],
            'latent': encode_dict['latent'],
        }

    def encode(self, field, sources, data=None, embeddings=None):
        """Encode input embeddings into the Perceiver latent space."""
        assert data is not None or embeddings is not None
        assert data is None or embeddings is None
        if embeddings is None:
            embeddings = self.embeddings(
                data, sources=sources, downsample=self.downsample_encoder)
        all_embeddings = torch.cat([emb[field] for emb in embeddings], 1)
        if self.training and self.sample_encoding_rays > 0:
            # Subsample encoding rays: an integer is an absolute count,
            # a float is a ratio of the available rays
            tot = self.sample_encoding_rays if is_int(self.sample_encoding_rays) \
                else int(self.sample_encoding_rays * all_embeddings.shape[1])
            all_embeddings = torch.stack(
                [all_embeddings[i, torch.randperm(all_embeddings.shape[1])[:tot], :]
                 for i in range(all_embeddings.shape[0])], 0)
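        # Cross-attend the (possibly subsampled) tokens into the latent array
        # and run the self-attention blocks; the latent summarizes the scene.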
        return {
            'embeddings': embeddings,
            'latent': self.model(inputs=all_embeddings).last_hidden_state,
        }

    def decode(self, latent, field, sources=None, data=None, embeddings=None,
               shape=None, sample_queries=0, filter_invalid=False):
        """Decode depth and/or RGB predictions from the latent at query rays."""
        assert data is not None or embeddings is not None
        assert data is None or embeddings is None
        if embeddings is None:
            shape = data[0]['cam'].hw
            shape = [s // self.downsample_decoder for s in shape]
            embeddings = self.embeddings(
                data, sources=sources, downsample=self.downsample_decoder)
        output = {}
        if self.training and (sample_queries > 0):  # or filter_invalid):
            # With subsampled queries the output is no longer a full image,
            # so no spatial shape is passed to the decoders
            output['query_idx'], embeddings = self.sample_decoder(
                data, embeddings, field, sample_queries, filter_invalid)
            shape = None
        if 'rgb' in self.tasks:
            output['rgb'] = [
                self.decoder_rgb(query=emb[field], z=latent, shape=shape).logits
                for emb in embeddings]
        if 'depth' in self.tasks:
            output['depth'] = [
                self.decoder(query=emb[field], z=latent, shape=shape).logits
                for emb in embeddings]
        return {
            'embeddings': embeddings,
            'output': output,
        }
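# ---------------------------------------------------------------------------
# Illustration only, not part of the original network: a minimal, standalone
# sketch of the ray-subsampling trick used in `encode` above. All shapes are
# made up for the example; a float `sample_encoding_rays` keeps that ratio of
# each sample's input tokens.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    tokens = torch.randn(2, 4096, 186)  # (batch, rays, token_dim), dummy sizes
    ratio = 0.25
    tot = int(ratio * tokens.shape[1])
    sampled = torch.stack([
        tokens[b, torch.randperm(tokens.shape[1])[:tot], :]
        for b in range(tokens.shape[0])
    ], 0)
    print(sampled.shape)  # torch.Size([2, 1024, 186])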