# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# dpt head implementation for DUST3R
# Downstream heads assume inputs of size B x N x C (where N is the number of tokens) ;
# or if it takes as input the output at every layer, the attribute return_all_layers should be set to True
# the forward function also takes as input img_info, which is indexed positionally:
# img_info[0] is used as the height and img_info[1] as the width
# for PixelwiseTask, the output will be of dimension B x num_channels x H x W
# --------------------------------------------------------
from einops import rearrange
from typing import List
import torch
import torch.nn as nn
from utils.dust3r.heads.postprocess import postprocess
# import utils.dust3r.utils.path_to_croco  # noqa: F401
from utils.dust3r.dpt_block import DPTOutputAdapter  # noqa
from pdb import set_trace as st  # noqa: F401  (debugging helper)


class DPTOutputAdapter_fix(DPTOutputAdapter):
    """
    Adapt croco's DPTOutputAdapter implementation for dust3r:
    remove duplicated weights, and fix forward for dust3r
    """

    def init(self, dim_tokens_enc=768):
        """Run the parent ``init`` and drop the duplicated activation modules.

        Args:
            dim_tokens_enc: forwarded unchanged to ``DPTOutputAdapter.init``.
        """
        super().init(dim_tokens_enc)
        # these are duplicated weights
        del self.act_1_postprocess
        del self.act_2_postprocess
        del self.act_3_postprocess
        del self.act_4_postprocess

    def forward(self, encoder_tokens: List[torch.Tensor], image_size=None):
        """Decode the hooked token layers into a dense output map.

        Args:
            encoder_tokens: per-layer token tensors; the entries listed in
                ``self.hooks`` are selected from it.
            image_size: optional ``(H, W)`` pair. When ``None`` the adapter's
                own ``self.image_size`` is used instead.

        Returns:
            The result of ``self.head`` applied to the fused features.
        """
        assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
        image_size = self.image_size if image_size is None else image_size
        H, W = image_size

        # Number of patches in height and width
        N_H = H // (self.stride_level * self.P_H)
        N_W = W // (self.stride_level * self.P_W)

        # Hook decoder onto 4 layers from specified ViT layers
        layers = [encoder_tokens[hook] for hook in self.hooks]

        # Extract only task-relevant tokens and ignore global tokens.
        layers = [self.adapt_tokens(tokens) for tokens in layers]

        # Reshape tokens to spatial representation
        layers = [
            rearrange(tokens, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W)
            for tokens in layers
        ]

        layers = [self.act_postprocess[idx](feat) for idx, feat in enumerate(layers)]

        # Project layers to chosen feature dim
        layers = [self.scratch.layer_rn[idx](feat) for idx, feat in enumerate(layers)]

        # Fuse layers using refinement stages.
        # The deepest path is cropped to the spatial size of layers[2] before
        # it is fused with layers[2] in refinenet3.
        path_4 = self.scratch.refinenet4(layers[3])[:, :, :layers[2].shape[2], :layers[2].shape[3]]
        path_3 = self.scratch.refinenet3(path_4, layers[2])
        path_2 = self.scratch.refinenet2(path_3, layers[1])
        path_1 = self.scratch.refinenet1(path_2, layers[0])

        # Output head
        out = self.head(path_1)
        return out


class PixelwiseTaskWithDPT(nn.Module):
    """ DPT module for dust3r, can return 3D points + confidence for all pixels"""

    def __init__(self, *, n_cls_token=0, hooks_idx=None, dim_tokens=None,
                 output_width_ratio=1, num_channels=1, postprocess=None,
                 depth_mode=None, conf_mode=None, **kwargs):
        """Build the DPT adapter and store the post-processing configuration.

        Args:
            n_cls_token: must be 0; any other value is not implemented.
            hooks_idx: optional layer indices, passed to the adapter as ``hooks``.
            dim_tokens: optional value passed to the adapter's ``init`` as
                ``dim_tokens_enc``; when ``None`` the adapter default is used.
            output_width_ratio: forwarded to the adapter constructor.
            num_channels: forwarded to the adapter constructor.
            postprocess: optional callable invoked in ``forward`` as
                ``postprocess(out, depth_mode, conf_mode)``.
            depth_mode: second argument given to ``postprocess``.
            conf_mode: third argument given to ``postprocess``.
            **kwargs: extra keyword arguments forwarded to the adapter
                constructor.
        """
        super().__init__()
        self.return_all_layers = True  # backbone needs to return all layers
        self.postprocess = postprocess
        self.depth_mode = depth_mode
        self.conf_mode = conf_mode

        assert n_cls_token == 0, "Not implemented"
        dpt_args = dict(output_width_ratio=output_width_ratio,
                        num_channels=num_channels,
                        **kwargs)
        if hooks_idx is not None:
            dpt_args.update(hooks=hooks_idx)
        self.dpt = DPTOutputAdapter_fix(**dpt_args)
        dpt_init_args = {} if dim_tokens is None else {'dim_tokens_enc': dim_tokens}
        self.dpt.init(**dpt_init_args)

        # ! remove unused param
        del self.dpt.scratch.refinenet4.resConfUnit1

    def forward(self, x, img_info):
        """Run the DPT adapter and, if configured, the post-processing step.

        Args:
            x: per-layer tokens handed to the adapter as ``encoder_tokens``.
            img_info: indexable object; ``img_info[0]`` and ``img_info[1]`` are
                passed to the adapter as the image ``(H, W)``.

        Returns:
            The adapter output, post-processed when ``self.postprocess`` is set.
        """
        out = self.dpt(x, image_size=(img_info[0], img_info[1]))
        if self.postprocess:
            out = self.postprocess(out, self.depth_mode, self.conf_mode)
        return out


def create_dpt_head(net, has_conf=False):
    """
    return PixelwiseTaskWithDPT for given net params

    Reads ``dec_depth``, ``enc_embed_dim``, ``dec_embed_dim``, ``depth_mode``
    and ``conf_mode`` from ``net``. ``has_conf`` adds one output channel on top
    of the 3 base channels.
    """
    assert net.dec_depth > 9
    l2 = net.dec_depth
    feature_dim = 256
    last_dim = feature_dim // 2
    out_nchan = 3
    ed = net.enc_embed_dim
    dd = net.dec_embed_dim
    return PixelwiseTaskWithDPT(
        num_channels=out_nchan + has_conf,  # a bool has_conf counts as 0 or 1
        feature_dim=feature_dim,
        last_dim=last_dim,
        hooks_idx=[0, l2 * 2 // 4, l2 * 3 // 4, l2],
        # hook 0 uses the encoder embedding dim, the others the decoder one
        dim_tokens=[ed, dd, dd, dd],
        postprocess=postprocess,
        depth_mode=net.depth_mode,
        conf_mode=net.conf_mode,
        head_type='regression',
    )


def create_dpt_head_ln3diff(out_nchan, feature_dim, l2, dec_embed_dim,
                            patch_size=2, has_conf=False):
    """
    return PixelwiseTaskWithDPT built from explicit parameters

    Unlike ``create_dpt_head`` this variant does not read anything from a
    ``net`` object, and it configures no postprocess / depth_mode / conf_mode,
    so the head's ``forward`` returns the raw adapter output.

    Args:
        out_nchan: number of base output channels.
        feature_dim: passed to the adapter; ``last_dim`` is set to half of it.
        l2: layer count used to derive the four hook indices.
        dec_embed_dim: token dimension used for all four hooked layers.
        patch_size: forwarded to the adapter constructor.
        has_conf: adds one extra output channel when true.
    """
    last_dim = feature_dim // 2
    dd = dec_embed_dim
    return PixelwiseTaskWithDPT(
        num_channels=out_nchan + has_conf,  # a bool has_conf counts as 0 or 1
        feature_dim=feature_dim,
        last_dim=last_dim,
        patch_size=patch_size,
        # Zero-based index of the last layer in each quarter of the l2 layers.
        # NOTE(review): for l2 < 4 the first index becomes -1 — confirm callers
        # never pass such a small depth.
        hooks_idx=[(l2 * 1 // 4) - 1, (l2 * 2 // 4) - 1, (l2 * 3 // 4) - 1, l2 - 1],
        dim_tokens=[dd, dd, dd, dd],
        postprocess=None,
        head_type='regression_gs',
    )