import torch
import torch.nn as nn
from promptda.model.dpt import DPTHead
from promptda.model.config import model_configs
from promptda.utils.logger import Log
import os
from pathlib import Path
from huggingface_hub import hf_hub_download


class PromptDA(nn.Module):
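    """Prompt Depth Anything: metric depth estimation from an RGB image guided
    by a low-resolution prompt depth map, built from a DINOv2 encoder and a
    DPT-style decoder head.
    """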
    patch_size = 14  # patch size of the pretrained dinov2 model
    use_bn = False  # no batch norm in the DPT head
    use_clstoken = False  # the DPT head does not read out the class token
    output_act = 'sigmoid'  # head outputs in (0, 1); rescaled in denormalize

    def __init__(self,
                 encoder='vitl',
                 ckpt_path='data/checkpoints/promptda_vitl.ckpt'):
        super().__init__()
        model_config = model_configs[encoder]

        self.encoder = encoder
        self.model_config = model_config
        # build the DINOv2 backbone from the vendored torch.hub directory;
        # weights come from the checkpoint loaded below, not from torch.hub
        self.pretrained = torch.hub.load(
            'torchhub/facebookresearch_dinov2_main',
            'dinov2_{:}14'.format(encoder),
            source='local',
            pretrained=False)
        # embedding dimension of the backbone, read off the first attention block
        dim = self.pretrained.blocks[0].attn.qkv.in_features
        self.depth_head = DPTHead(nclass=1,
                                  in_channels=dim,
                                  features=model_config['features'],
                                  out_channels=model_config['out_channels'],
                                  use_bn=self.use_bn,
                                  use_clstoken=self.use_clstoken,
                                  output_act=self.output_act)

        # ImageNet mean/std used to normalize inputs for the pretrained dinov2
        # backbone, registered as buffers so they follow the model across devices
        self.register_buffer('_mean', torch.tensor(
            [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('_std', torch.tensor(
            [0.229, 0.224, 0.225]).view(1, 3, 1, 1))

        self.load_checkpoint(ckpt_path)
    
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, model_kwargs=None, **hf_kwargs):
        """
        Load a model from a local checkpoint file or a Hugging Face Hub repo.
        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file, or a Hub repo id.
        - `model_kwargs`: additional keyword arguments to override the constructor defaults.
        - `hf_kwargs`: additional keyword arguments passed to `hf_hub_download`. Ignored if `pretrained_model_name_or_path` is a local path.
        ### Returns:
        - A new instance of `PromptDA` with the parameters loaded from the checkpoint.
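        ### Example:
        >>> # repo id shown for illustration; substitute the project's published Hub repo
        >>> model = PromptDA.from_pretrained('depth-anything/promptda_vitl')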
        """
        if Path(pretrained_model_name_or_path).exists():
            ckpt_path = pretrained_model_name_or_path
        else:
            # not a local path: treat it as a Hub repo id and download the checkpoint
            ckpt_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="promptda_vitl.ckpt",
                **hf_kwargs
            )
        if model_kwargs is None:
            model_kwargs = {}
        model_kwargs.update({'ckpt_path': ckpt_path})
        model = cls(**model_kwargs)
        return model

    def load_checkpoint(self, ckpt_path):
        if os.path.exists(ckpt_path):
            Log.info(f'Loading checkpoint from {ckpt_path}')
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            # keys in the training checkpoint carry a 9-character wrapper
            # prefix; strip it so they match this module's parameter names
            self.load_state_dict(
                {k[9:]: v for k, v in checkpoint['state_dict'].items()})
        else:
            Log.warn(f'Checkpoint {ckpt_path} not found')

    def forward(self, x, prompt_depth=None):
        assert prompt_depth is not None, 'prompt_depth is required'
        # scale the prompt depth to [0, 1]; the sigmoid output of the head is
        # mapped back to the prompt's metric range in denormalize
        prompt_depth, min_val, max_val = self.normalize(prompt_depth)
        h, w = x.shape[-2:]
        features = self.pretrained.get_intermediate_layers(
            x, self.model_config['layer_idxs'],
            return_class_token=True)
        patch_h, patch_w = h // self.patch_size, w // self.patch_size
        depth = self.depth_head(features, patch_h, patch_w, prompt_depth)
        depth = self.denormalize(depth, min_val, max_val)
        return depth

    @torch.no_grad()
    def predict(self,
                image: torch.Tensor,
                prompt_depth: torch.Tensor) -> torch.Tensor:
        """Inference convenience wrapper around `forward` (no gradients)."""
        return self.forward(image, prompt_depth)

    def normalize(self,
                  prompt_depth: torch.Tensor):
        # per-sample min-max normalization of the prompt depth to [0, 1];
        # quantiles 0 and 1 are the per-sample minimum and maximum
        # (assumes a non-constant prompt, i.e. max_val > min_val)
        B, C, H, W = prompt_depth.shape
        min_val = torch.quantile(
            prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None]
        max_val = torch.quantile(
            prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None]
        prompt_depth = (prompt_depth - min_val) / (max_val - min_val)
        return prompt_depth, min_val, max_val

    def denormalize(self,
                    depth: torch.Tensor,
                    min_val: torch.Tensor,
                    max_val: torch.Tensor):
        # inverse of normalize: map a depth in [0, 1] back to the prompt's metric range
        return depth * (max_val - min_val) + min_val
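

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original API). Shapes are
    # illustrative: H and W must be multiples of the 14-pixel patch size, and
    # the prompt depth is a low-resolution metric map (e.g. from a LiDAR
    # sensor). ImageNet normalization via _mean/_std may be expected to happen
    # in preprocessing; random inputs here only exercise the shapes.
    model = PromptDA(encoder='vitl').eval()
    image = torch.rand(1, 3, 756, 1008)        # 756 = 54 * 14, 1008 = 72 * 14
    prompt_depth = torch.rand(1, 1, 192, 256)  # stand-in for real sensor depth
    depth = model.predict(image, prompt_depth)
    Log.info(f'depth shape: {depth.shape}')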