import os
from pathlib import Path

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download

from promptda.model.dpt import DPTHead
from promptda.model.config import model_configs
from promptda.utils.logger import Log


class PromptDA(nn.Module):
    patch_size = 14  # patch size of the pretrained DINOv2 backbone
    use_bn = False
    use_clstoken = False
    output_act = 'sigmoid'

    def __init__(self,
                 encoder='vitl',
                 ckpt_path='data/checkpoints/promptda_vitl.ckpt'):
        super().__init__()
        model_config = model_configs[encoder]
        self.encoder = encoder
        self.model_config = model_config
        # Build the DINOv2 backbone from the locally vendored torch.hub
        # directory; its weights come from the checkpoint loaded below.
        self.pretrained = torch.hub.load(
            'torchhub/facebookresearch_dinov2_main',
            'dinov2_{:}14'.format(encoder),
            source='local',
            pretrained=False)
        dim = self.pretrained.blocks[0].attn.qkv.in_features
        self.depth_head = DPTHead(nclass=1,
                                  in_channels=dim,
                                  features=model_config['features'],
                                  out_channels=model_config['out_channels'],
                                  use_bn=self.use_bn,
                                  use_clstoken=self.use_clstoken,
                                  output_act=self.output_act)
        # Mean and std of the pretrained DINOv2 model.
        self.register_buffer('_mean', torch.tensor(
            [0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer('_std', torch.tensor(
            [0.229, 0.224, 0.225]).view(1, 3, 1, 1))
        self.load_checkpoint(ckpt_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path=None, model_kwargs=None, **hf_kwargs):
        """
        Load a model from a checkpoint file or a Hugging Face Hub repo.

        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments passed to the constructor.
        - `hf_kwargs`: additional keyword arguments passed to `hf_hub_download`. Ignored if `pretrained_model_name_or_path` is a local path.

        ### Returns:
        - A new instance of `PromptDA` with the parameters loaded from the checkpoint.
        """
        if Path(pretrained_model_name_or_path).exists():
            ckpt_path = pretrained_model_name_or_path
        else:
            # Treat the argument as a Hub repo id and download the checkpoint.
            ckpt_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="promptda_vitl.ckpt",
                **hf_kwargs
            )
        # NOTE: the model config is not read from the checkpoint here; use
        # model_kwargs to override constructor defaults instead.
        if model_kwargs is None:
            model_kwargs = {}
        model_kwargs.update({'ckpt_path': ckpt_path})
        model = cls(**model_kwargs)
        return model
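
    # Usage (a sketch; the repo id below is a placeholder, not a verified
    # Hugging Face repo):
    #   model = PromptDA.from_pretrained('data/checkpoints/promptda_vitl.ckpt')
    #   model = PromptDA.from_pretrained('<org>/<promptda-repo>')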

    def load_checkpoint(self, ckpt_path):
        if os.path.exists(ckpt_path):
            Log.info(f'Loading checkpoint from {ckpt_path}')
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            # Strip the first 9 characters of each key (the training-time
            # wrapper prefix) so names match this module's parameters.
            self.load_state_dict(
                {k[9:]: v for k, v in checkpoint['state_dict'].items()})
        else:
            Log.warn(f'Checkpoint {ckpt_path} not found')

    def forward(self, x, prompt_depth=None):
        assert prompt_depth is not None, 'prompt_depth is required'
        # Min-max normalize the prompt; the same range restores metric
        # scale after the depth head.
        prompt_depth, min_val, max_val = self.normalize(prompt_depth)
        h, w = x.shape[-2:]
        features = self.pretrained.get_intermediate_layers(
            x, self.model_config['layer_idxs'],
            return_class_token=True)
        patch_h, patch_w = h // self.patch_size, w // self.patch_size
        depth = self.depth_head(features, patch_h, patch_w, prompt_depth)
        depth = self.denormalize(depth, min_val, max_val)
        return depth

    def predict(self,
                image: torch.Tensor,
                prompt_depth: torch.Tensor):
        return self.forward(image, prompt_depth)

    def normalize(self,
                  prompt_depth: torch.Tensor):
        B, C, H, W = prompt_depth.shape
        # Quantiles 0 and 1 are the per-sample min and max of the prompt.
        min_val = torch.quantile(
            prompt_depth.reshape(B, -1), 0., dim=1, keepdim=True)[:, :, None, None]
        max_val = torch.quantile(
            prompt_depth.reshape(B, -1), 1., dim=1, keepdim=True)[:, :, None, None]
        prompt_depth = (prompt_depth - min_val) / (max_val - min_val)
        return prompt_depth, min_val, max_val
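
    # Worked example (illustrative numbers, not from the original source):
    # a prompt spanning [0.5, 4.5] meters maps to [0, 1] with min_val = 0.5
    # and max_val = 4.5; denormalize() inverts this, so
    # denormalize(*normalize(d)) recovers d up to floating-point error.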

    def denormalize(self,
                    depth: torch.Tensor,
                    min_val: torch.Tensor,
                    max_val: torch.Tensor):
        return depth * (max_val - min_val) + min_val
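

if __name__ == '__main__':
    # Minimal smoke-test sketch (not part of the original module). Input
    # sizes are assumptions: image height/width need only be multiples of
    # the DINOv2 patch size (14); the low-resolution prompt mimics a
    # LiDAR-style depth map. If the default ckpt_path is missing, the model
    # stays randomly initialized, so this only checks shapes and wiring.
    model = PromptDA().eval()
    image = torch.rand(1, 3, 14 * 40, 14 * 54)       # (B, 3, H, W), H and W divisible by 14
    prompt_depth = torch.rand(1, 1, 192, 256) * 3.0  # metric depth prompt in meters
    with torch.no_grad():
        depth = model.predict(image, prompt_depth)
    print(depth.shape)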