|
|
|
|
|
|
|
|
|
|
|
import itertools |
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
_DINOV2_BASE_URL = 'https://dl.fbaipublicfiles.com/dinov2' |
|
|
|
|
|
def _make_dinov2_model_name( |
|
arch_name: str, patch_size: int, num_register_tokens: int = 0 |
|
) -> str: |
|
compact_arch_name = arch_name.replace('_', '')[:4] |
|
registers_suffix = ( |
|
f'_reg{num_register_tokens}' if num_register_tokens else '' |
|
) |
|
return f'dinov2_{compact_arch_name}{patch_size}{registers_suffix}' |
|
|
|
|
|
class CenterPadding(nn.Module):
    """Pad the trailing dimensions of a tensor up to a multiple of a size.

    Every dimension from index 2 onward is extended to the next multiple of
    ``multiple``. The padding is split between both ends of the dimension;
    when the amount is odd, the extra element goes to the trailing end.

    NOTE(review): dimensions 0 and 1 are never padded, which suggests a
    (batch, channels, *spatial) layout -- confirm against the callers.
    """

    def __init__(self, multiple):
        """Store the padding granularity.

        Args:
            multiple: Positive integer each padded dimension must become a
                multiple of.

        Raises:
            ValueError: If ``multiple`` is smaller than 1. Zero would divide
                by zero in ``_get_pad`` and a negative value would yield
                negative pad amounts.
        """
        super().__init__()
        if multiple < 1:
            raise ValueError(
                f'multiple must be a positive integer, got {multiple!r}'
            )
        self.multiple = multiple

    def _get_pad(self, size):
        """Return ``(pad_before, pad_after)`` for one dimension.

        Args:
            size: Current length of the dimension.

        Returns:
            A pair of non-negative pad amounts whose sum brings ``size`` up
            to the next multiple of ``self.multiple``; ``(0, 0)`` when it
            already is one.
        """
        # Integer ceiling division: unlike math.ceil(size / multiple) this
        # never goes through a float, so it stays exact for any int size.
        new_size = -(-size // self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        # The remainder (one element larger when pad_size is odd) goes last.
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    # NOTE(review): inference_mode disables autograd tracking for this call,
    # so the result cannot take part in a backward pass -- confirm that this
    # module is never used on a training path.
    @torch.inference_mode()
    def forward(self, x):
        """Pad ``x`` so every dimension past the first two is a multiple.

        Args:
            x: Input tensor; dimensions 0 and 1 are left untouched.

        Returns:
            The result of ``F.pad`` applied with the computed pad amounts.
        """
        # F.pad takes (before, after) pairs starting from the LAST dimension,
        # hence the reversed walk over shape[2:] via the [:1:-1] slice.
        pads = list(
            itertools.chain.from_iterable(
                self._get_pad(dim_size) for dim_size in x.shape[:1:-1]
            )
        )
        return F.pad(x, pads)
|
|