|
import functools |
|
import logging |
|
import typing |
|
|
|
import beartype |
|
import torch |
|
from jaxtyping import Float, jaxtyped |
|
from torch import Tensor |
|
from torchvision.transforms import v2 |
|
|
|
logger = logging.getLogger("modeling.py") |
|
|
|
|
|
@jaxtyped(typechecker=beartype.beartype)
class SplitDinov2(torch.nn.Module):
    """DINOv2 ViT-B/14 (with registers) split into two sequential halves.

    `forward_start` runs patch embedding plus the first `split_at` transformer
    blocks; `forward_end` runs the remaining blocks, the final norm, and strips
    the CLS/register tokens. Composing the two reproduces a full forward pass,
    which lets callers intercept or edit the intermediate activations.
    """

    def __init__(self, *, split_at: int):
        """
        Args:
            split_at: number of leading transformer blocks executed by
                `forward_start`; the rest are executed by `forward_end`.
        """
        super().__init__()

        # Pretrained backbone is frozen in eval mode; weights come from torch.hub.
        self.vit = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg").eval()
        self.split_at = split_at

    def forward_start(
        self, x: Float[Tensor, "batch channels width height"]
    ) -> Float[Tensor, "batch total_patches dim"]:
        """Embed the image and run the first `split_at` transformer blocks.

        Returns the intermediate token sequence (CLS + register + patch tokens).
        """
        x_BPD = self.vit.prepare_tokens_with_masks(x)
        for blk in self.vit.blocks[: self.split_at]:
            x_BPD = blk(x_BPD)

        return x_BPD

    def forward_end(
        self, x_BPD: Float[Tensor, "batch total_patches dim"]
    ) -> Float[Tensor, "batch patches dim"]:
        """Run the remaining blocks and final norm; return patch tokens only.

        Raises a shape error (via jaxtyping/beartype) if `x_BPD` does not match
        the shape produced by `forward_start`.
        """
        # Bug fix: the original iterated `blocks[-self.split_at:]`, which equals
        # `blocks[self.split_at:]` only when split_at is exactly half the depth.
        # With split_at=11 on the 12-block ViT-B, that re-ran blocks 1-10 that
        # forward_start had already applied. Continue from `split_at` instead.
        for blk in self.vit.blocks[self.split_at :]:
            x_BPD = blk(x_BPD)

        x_BPD = self.vit.norm(x_BPD)
        # Drop the CLS token and the register tokens, keeping patch tokens.
        return x_BPD[:, self.vit.num_register_tokens + 1 :]
|
|
|
|
|
@functools.cache
def load_vit(device: str) -> tuple[SplitDinov2, typing.Callable]:
    """Load the split DINOv2 model on `device` with its preprocessing transform.

    Cached per device string, so repeated calls return the same model and
    transform objects instead of re-downloading/re-initializing the backbone.

    Args:
        device: torch device string (e.g. "cpu", "cuda:0") the model is moved to.

    Returns:
        Tuple of (model, transform) where transform maps a PIL image to a
        normalized float tensor ready for the model.
    """
    model = SplitDinov2(split_at=11).to(device)

    # Standard ImageNet preprocessing: resize, center-crop to 224x224,
    # convert to a float tensor in [0, 1], then normalize per channel.
    steps = [
        v2.Resize(size=(256, 256)),
        v2.CenterCrop(size=(224, 224)),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.4850, 0.4560, 0.4060], std=[0.2290, 0.2240, 0.2250]),
    ]
    transform = v2.Compose(steps)
    logger.info("Loaded ViT.")

    return model, transform
|
|