dmitriitochilkin's picture
add dependencies
ff49a48
raw
history blame
2.14 kB
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers.models.vit.modeling_vit import ViTModel
from ...utils import BaseModule
class DINOSingleImageTokenizer(BaseModule):
    """Encode RGB images into per-token feature maps with a ViT backbone.

    The backbone architecture is read from the ``config.json`` of the
    configured Hugging Face repository. Only the configuration is fetched
    here; this class does not load pretrained weights itself, so they are
    presumably restored later from an enclosing checkpoint (confirm
    against the caller).
    """

    @dataclass
    class Config(BaseModule.Config):
        # Hugging Face Hub repository whose ``config.json`` defines the ViT.
        pretrained_model_name_or_path: str = "facebook/dino-vitb16"
        # Enable activation checkpointing in the ViT (only matters when
        # the module is run in training mode).
        enable_gradient_checkpointing: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the ViT backbone and register the normalisation buffers."""
        config_path = hf_hub_download(
            repo_id=self.cfg.pretrained_model_name_or_path,
            filename="config.json",
        )
        # Instantiating from a config builds the architecture only.
        self.model: ViTModel = ViTModel(
            ViTModel.config_class.from_pretrained(config_path)
        )

        if self.cfg.enable_gradient_checkpointing:
            # Go through the public API rather than flipping
            # ``encoder.gradient_checkpointing`` by hand: newer transformers
            # releases also need the checkpointing function installed, which
            # only ``gradient_checkpointing_enable`` takes care of.
            self.model.gradient_checkpointing_enable()

        # Per-channel normalisation constants (the values match the standard
        # ImageNet RGB statistics), shaped to broadcast over
        # (B, N, C, H, W). Non-persistent: they follow the module across
        # devices but are not written to the state dict.
        self.register_buffer(
            "image_mean",
            torch.as_tensor([0.485, 0.456, 0.406]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "image_std",
            torch.as_tensor([0.229, 0.224, 0.225]).reshape(1, 1, 3, 1, 1),
            persistent=False,
        )

    def forward(self, images: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
        """Tokenize a batch of images.

        Args:
            images: Either ``(B, C, H, W)`` for one view per sample or
                ``(B, N, C, H, W)`` for ``N`` views per sample. ``C`` must
                be 3 to match the normalisation buffers. Values are
                presumably expected in ``[0, 1]`` (TODO confirm with the
                caller).
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The backbone's last hidden state with the channel axis moved
            before the token axis: ``(B, Ct, Nt)`` for 4-D input and
            ``(B, N, Ct, Nt)`` for 5-D input.

        Raises:
            ValueError: If ``images`` is neither 4- nor 5-dimensional.
        """
        if images.ndim not in (4, 5):
            # Fail early: a lower-rank tensor would otherwise broadcast
            # silently against the 5-D mean/std buffers and only blow up
            # later inside einops with an unrelated-looking message.
            raise ValueError(
                "images must have shape (B, C, H, W) or (B, N, C, H, W), "
                f"got {tuple(images.shape)}"
            )

        # A 4-D input is treated as a single view per sample; the view
        # axis is added here and removed again before returning.
        packed = images.ndim == 4
        if packed:
            images = images.unsqueeze(1)

        batch_size = images.shape[0]
        images = (images - self.image_mean) / self.image_std

        # Fold the view axis into the batch axis for the backbone.
        # ``interpolate_pos_encoding`` lets the ViT accept resolutions other
        # than the one its position embeddings were created for.
        out = self.model(
            rearrange(images, "B N C H W -> (B N) C H W"),
            interpolate_pos_encoding=True,
        )

        # (B*N, Nt, Ct) -> (B*N, Ct, Nt) -> (B, N, Ct, Nt)
        local_features = out.last_hidden_state.permute(0, 2, 1)
        local_features = rearrange(
            local_features, "(B N) Ct Nt -> B N Ct Nt", B=batch_size
        )

        if packed:
            local_features = local_features.squeeze(1)
        return local_features

    def detokenize(self, *args, **kwargs):
        """Not supported by this tokenizer; always raises."""
        raise NotImplementedError