add files
- README.md +47 -3
- __init__.py +17 -0
- data.py +130 -0
- environment.yml +285 -0
- hub/__init__.py +4 -0
- hub/backbones.py +201 -0
- hub/utils.py +46 -0
- hubconf.py +29 -0
- main.py +238 -0
- models/dinov2/__init__.py +46 -0
- models/dinov2/layers/__init__.py +11 -0
- models/dinov2/layers/attention.py +91 -0
- models/dinov2/layers/block.py +304 -0
- models/dinov2/layers/dino_head.py +69 -0
- models/dinov2/layers/drop_path.py +37 -0
- models/dinov2/layers/layer_scale.py +26 -0
- models/dinov2/layers/mlp.py +40 -0
- models/dinov2/layers/patch_embed.py +100 -0
- models/dinov2/layers/swiglu_ffn.py +69 -0
- models/dinov2/vision_transformer.py +454 -0
- neighbor_loss.py +137 -0
- pyproject.toml +17 -0
- repair.py +147 -0
- setup.cfg +23 -0
- singular_defect.py +149 -0
- utils.py +70 -0
- visualize.py +82 -0
README.md
CHANGED
@@ -1,3 +1,47 @@
# Official code for SINDER: Repairing the Singular Defects of DINOv2 (ECCV 2024 Oral)

[![🦢 - Paper](https://img.shields.io/badge/🦢-Paper-red)](https://arxiv.org/abs/2407.16826)

![SINDER](./resources/high_norm.jpg)

## Citation

```bibtex
@InProceedings{Haoqi_2024_ECCV,
    author    = {Wang, Haoqi and Zhang, Tong and Salzmann, Mathieu},
    title     = {{SINDER}: Repairing the Singular Defects of DINOv2},
    booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
    month     = {September},
    year      = {2024}
}
```

## Install

```bash
conda env create -f environment.yml
conda activate sinder
pip install -e .
```

## Train

Put the ImageNet-1K dataset in the `data/imagenet` folder.

```bash
./main.py
```

## Visualize

```bash
# visualize the original dinov2
./visualize.py resources/example.jpg

# visualize sinder; checkpoint download link below
./visualize.py resources/example.jpg --checkpoint path/to/sinder.pth
```

## Checkpoints

[Google Drive](https://drive.google.com/file/d/1g0Aq5qXYuMmVrN9-gGwC9ybwlCDFAw-l/view?usp=sharing)
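For programmatic use of a downloaded checkpoint, a minimal, hypothetical sketch: it relies on the training code later in this commit, which checkpoints the whole module with `torch.save(model, ...)`, so the repaired model can be deserialized directly. The path is a placeholder.

```python
import torch

# Sketch only: checkpoints are written with torch.save(model, path),
# so torch.load returns the full nn.Module (the sinder package must be
# importable for unpickling to succeed).
model = torch.load('path/to/sinder.pth', map_location='cuda')
model.eval()
```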
__init__.py
ADDED
@@ -0,0 +1,17 @@
from .data import load_data, load_visual_data
from .neighbor_loss import get_neighbor_loss
from .repair import replace_back, replace_linear_addition_noqk
from .singular_defect import singular_defect_directions
from .utils import get_tokens, load_model, pca_array

__all__ = [
    'singular_defect_directions',
    'get_neighbor_loss',
    'pca_array',
    'get_tokens',
    'load_data',
    'load_visual_data',
    'replace_back',
    'replace_linear_addition_noqk',
    'load_model',
]
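For orientation, the re-exports above form the package's public API; a small sketch of using it (assumes `pip install -e .` from the README, and that the model name matches the training default used in `main.py` below):

```python
# Sketch using the API re-exported above; downloads DINOv2 weights on first use.
from sinder import load_model

model = load_model('dinov2_vitg14')  # same default model name as main.py
```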
data.py
ADDED
@@ -0,0 +1,130 @@
from glob import glob

import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


def make_transform(
    smaller_edge_size: int, patch_size, center_crop=False, max_edge_size=812
) -> transforms.Compose:
    IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
    interpolation_mode = transforms.InterpolationMode.BICUBIC
    assert smaller_edge_size > 0

    if center_crop:
        return transforms.Compose(
            [
                transforms.Resize(
                    size=smaller_edge_size,
                    interpolation=interpolation_mode,
                    antialias=True,
                ),
                transforms.CenterCrop(smaller_edge_size),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
                ),
                transforms.Lambda(
                    lambda img: img[
                        :,
                        : min(
                            max_edge_size,
                            (img.shape[1] - img.shape[1] % patch_size),
                        ),
                        : min(
                            max_edge_size,
                            (img.shape[2] - img.shape[2] % patch_size),
                        ),
                    ]
                ),
            ]
        )
    else:
        return transforms.Compose(
            [
                transforms.Resize(
                    size=smaller_edge_size,
                    interpolation=interpolation_mode,
                    antialias=True,
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
                ),
                transforms.Lambda(
                    lambda img: img[
                        :,
                        : min(
                            max_edge_size,
                            (img.shape[1] - img.shape[1] % patch_size),
                        ),
                        : min(
                            max_edge_size,
                            (img.shape[2] - img.shape[2] % patch_size),
                        ),
                    ]
                ),
            ]
        )


class VisualDataset(Dataset):
    def __init__(self, transform, imgs=None):
        self.transform = transform
        if imgs is None:
            self.files = [
                'resources/example.jpg',
                'resources/villa.png',
                'resources/000000037740.jpg',
                'resources/000000064359.jpg',
                'resources/000000066635.jpg',
                'resources/000000078420.jpg',
            ]
        else:
            self.files = imgs

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = self.files[index]
        img = Image.open(img).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img


class ImageNetDataset(Dataset):
    def __init__(self, transform, num_train_max=1000000):
        self.transform = transform
        self.files = glob('data/imagenet/train/*/*.JPEG')
        # guard against a zero step (slice step of 0 raises ValueError)
        # when fewer files than num_train_max are found
        step = max(len(self.files) // num_train_max, 1)
        self.files = self.files[::step]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index]).convert('RGB')
        img = self.transform(img)
        return img


def load_data(args, model):
    transform = make_transform(
        args.resolution, model.patch_size, center_crop=True
    )
    dataset = ImageNetDataset(
        transform=transform, num_train_max=args.num_train_max
    )
    return dataset


def load_visual_data(args, model):
    transform = make_transform(
        args.visual_size, model.patch_size, max_edge_size=1792
    )
    dataset = VisualDataset(transform=transform, imgs=vars(args).get('imgs'))
    return dataset
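A quick sketch of how these pieces compose, assuming a DINOv2-style model with `patch_size == 14` (the resolution and image path are illustrative):

```python
from sinder.data import make_transform, VisualDataset

# Shorter edge resized to 448, then cropped to a multiple of the 14-pixel
# patch size by the Lambda at the end of the transform.
transform = make_transform(448, patch_size=14)
dataset = VisualDataset(transform, imgs=['resources/example.jpg'])
img = dataset[0]   # normalized CHW float tensor
print(img.shape)   # e.g. torch.Size([3, 448, 602]); H and W divisible by 14
```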
environment.yml
ADDED
@@ -0,0 +1,285 @@
name: sinder
channels:
  - pytorch
  - nvidia
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1
  - _openmp_mutex=4.5
  - alsa-lib=1.2.12
  - aom=3.9.1
  - asttokens=2.4.1
  - black=22.1.0
  - blas=2.116
  - blas-devel=3.9.0
  - blue=0.9.1
  - brotli-python=1.1.0
  - bzip2=1.0.8
  - c-ares=1.33.1
  - ca-certificates=2024.8.30
  - cairo=1.18.0
  - certifi=2024.8.30
  - cffi=1.16.0
  - cfgv=3.3.1
  - charset-normalizer=3.3.2
  - click=8.1.7
  - colorama=0.4.6
  - cuda-cudart=12.1.105
  - cuda-cupti=12.1.105
  - cuda-libraries=12.1.0
  - cuda-nvrtc=12.1.105
  - cuda-nvtx=12.1.105
  - cuda-opencl=12.5.39
  - cuda-runtime=12.1.0
  - cuda-version=12.5
  - dataclasses=0.8
  - dav1d=1.2.1
  - dbus=1.13.6
  - decorator=5.1.1
  - distlib=0.3.8
  - double-conversion=3.3.0
  - exceptiongroup=1.2.2
  - executing=2.0.1
  - expat=2.6.2
  - ffmpeg=6.1.2
  - filelock=3.15.4
  - flake8=4.0.1
  - font-ttf-dejavu-sans-mono=2.37
  - font-ttf-inconsolata=3.000
  - font-ttf-source-code-pro=2.038
  - font-ttf-ubuntu=0.83
  - fontconfig=2.14.2
  - fonts-conda-ecosystem=1
  - fonts-conda-forge=1
  - freeglut=3.2.2
  - freetype=2.12.1
  - fribidi=1.0.10
  - gettext=0.22.5
  - gettext-tools=0.22.5
  - gmp=6.3.0
  - gmpy2=2.1.5
  - gnutls=3.7.9
  - graphite2=1.3.13
  - h2=4.1.0
  - harfbuzz=9.0.0
  - hdf5=1.14.3
  - hpack=4.0.0
  - hyperframe=6.0.1
  - icu=75.1
  - identify=2.6.0
  - idna=3.7
  - imath=3.1.12
  - importlib-metadata=8.5.0
  - ipdb=0.13.13
  - ipython=8.26.0
  - jasper=4.2.4
  - jedi=0.19.1
  - jinja2=3.1.4
  - joblib=1.4.2
  - keyutils=1.6.1
  - krb5=1.21.3
  - lame=3.100
  - lcms2=2.16
  - ld_impl_linux-64=2.40
  - lerc=4.0.0
  - libabseil=20240116.2
  - libaec=1.1.3
  - libasprintf=0.22.5
  - libasprintf-devel=0.22.5
  - libass=0.17.3
  - libblas=3.9.0
  - libcblas=3.9.0
  - libclang-cpp18.1=18.1.8
  - libclang13=19.1.0
  - libcublas=12.1.0.26
  - libcufft=11.0.2.4
  - libcufile=1.10.1.7
  - libcups=2.3.3
  - libcurand=10.3.6.82
  - libcurl=8.10.1
  - libcusolver=11.4.4.55
  - libcusparse=12.0.2.55
  - libdeflate=1.21
  - libdrm=2.4.122
  - libedit=3.1.20191231
  - libegl=1.7.0
  - libev=4.33
  - libexpat=2.6.2
  - libffi=3.4.2
  - libgcc=14.1.0
  - libgcc-ng=14.1.0
  - libgettextpo=0.22.5
  - libgettextpo-devel=0.22.5
  - libgfortran=14.1.0
  - libgfortran-ng=14.1.0
  - libgfortran5=14.1.0
  - libgl=1.7.0
  - libglib=2.80.3
  - libglu=9.0.0
  - libglvnd=1.7.0
  - libglx=1.7.0
  - libgomp=14.1.0
  - libhwloc=2.11.1
  - libiconv=1.17
  - libidn2=2.3.7
  - libjpeg-turbo=3.0.0
  - liblapack=3.9.0
  - liblapacke=3.9.0
  - libllvm18=18.1.8
  - libllvm19=19.1.0
  - libnghttp2=1.58.0
  - libnpp=12.0.2.50
  - libnsl=2.0.1
  - libnvjitlink=12.1.105
  - libnvjpeg=12.1.1.14
  - libopencv=4.10.0
  - libopenvino=2024.4.0
  - libopenvino-auto-batch-plugin=2024.4.0
  - libopenvino-auto-plugin=2024.4.0
  - libopenvino-hetero-plugin=2024.4.0
  - libopenvino-intel-cpu-plugin=2024.4.0
  - libopenvino-intel-gpu-plugin=2024.4.0
  - libopenvino-intel-npu-plugin=2024.4.0
  - libopenvino-ir-frontend=2024.4.0
  - libopenvino-onnx-frontend=2024.4.0
  - libopenvino-paddle-frontend=2024.4.0
  - libopenvino-pytorch-frontend=2024.4.0
  - libopenvino-tensorflow-frontend=2024.4.0
  - libopenvino-tensorflow-lite-frontend=2024.4.0
  - libopus=1.3.1
  - libpciaccess=0.18
  - libpng=1.6.44
  - libpq=16.4
  - libprotobuf=4.25.3
  - libsqlite=3.46.0
  - libssh2=1.11.0
  - libstdcxx=14.1.0
  - libstdcxx-ng=14.1.0
  - libtasn1=4.19.0
  - libtiff=4.7.0
  - libunistring=0.9.10
  - libuuid=2.38.1
  - libva=2.22.0
  - libvpx=1.14.1
  - libwebp-base=1.4.0
  - libxcb=1.16
  - libxcrypt=4.4.36
  - libxkbcommon=1.7.0
  - libxml2=2.12.7
  - libzlib=1.3.1
  - llvm-openmp=15.0.7
  - markupsafe=2.1.5
  - matplotlib-inline=0.1.7
  - mccabe=0.6.1
  - mkl=2022.1.0
  - mkl-devel=2022.1.0
  - mkl-include=2022.1.0
  - mpc=1.3.1
  - mpfr=4.2.1
  - mpmath=1.3.0
  - mypy_extensions=1.0.0
  - mysql-common=9.0.1
  - mysql-libs=9.0.1
  - ncurses=6.5
  - nettle=3.9.1
  - networkx=3.3
  - nodeenv=1.9.1
  - numpy=2.0.0
  - ocl-icd=2.3.2
  - opencv=4.10.0
  - openexr=3.2.2
  - openh264=2.4.1
  - openjpeg=2.5.2
  - openssl=3.3.2
  - p11-kit=0.24.1
  - parso=0.8.4
  - pathspec=0.12.1
  - pcre2=10.44
  - pexpect=4.9.0
  - pickleshare=0.7.5
  - pillow=10.4.0
  - pip=24.0
  - pixman=0.43.2
  - platformdirs=4.2.2
  - pre-commit=3.7.1
  - prompt-toolkit=3.0.47
  - pthread-stubs=0.4
  - ptyprocess=0.7.0
  - pugixml=1.14
  - pure_eval=0.2.2
  - py-opencv=4.10.0
  - pycodestyle=2.8.0
  - pycparser=2.22
  - pyflakes=2.4.0
  - pygments=2.18.0
  - pysocks=1.7.1
  - python=3.10.14
  - python_abi=3.10
  - pytorch=2.3.1
  - pytorch-cuda=12.1
  - pytorch-mutex=1.0
  - pyyaml=6.0.1
  - qt6-main=6.7.2
  - readline=8.2
  - requests=2.32.3
  - scikit-learn=1.5.1
  - scipy=1.14.0
  - setuptools=71.0.4
  - six=1.16.0
  - snappy=1.2.1
  - stack_data=0.6.2
  - svt-av1=2.2.1
  - sympy=1.13.0
  - tbb=2021.13.0
  - threadpoolctl=3.5.0
  - tk=8.6.13
  - toml=0.10.2
  - tomli=2.0.1
  - torchaudio=2.3.1
  - torchtriton=2.3.1
  - torchvision=0.18.1
  - tqdm=4.66.5
  - traitlets=5.14.3
  - typed-ast=1.5.5
  - typing_extensions=4.12.2
  - tzdata=2024a
  - ukkonen=1.0.1
  - urllib3=2.2.2
  - virtualenv=20.26.3
  - wayland=1.23.0
  - wayland-protocols=1.36
  - wcwidth=0.2.13
  - wheel=0.43.0
  - x264=1!164.3095
  - x265=3.5
  - xcb-util=0.4.1
  - xcb-util-cursor=0.1.5
  - xcb-util-image=0.4.0
  - xcb-util-keysyms=0.4.1
  - xcb-util-renderutil=0.3.10
  - xcb-util-wm=0.4.2
  - xkeyboard-config=2.42
  - xorg-fixesproto=5.0
  - xorg-inputproto=2.3.2
  - xorg-kbproto=1.0.7
  - xorg-libice=1.1.1
  - xorg-libsm=1.2.4
  - xorg-libx11=1.8.9
  - xorg-libxau=1.0.11
  - xorg-libxdmcp=1.1.3
  - xorg-libxext=1.3.4
  - xorg-libxfixes=5.0.3
  - xorg-libxi=1.7.10
  - xorg-libxrender=0.9.11
  - xorg-libxtst=1.2.5
  - xorg-libxxf86vm=1.1.5
  - xorg-recordproto=1.14.2
  - xorg-renderproto=0.11.1
  - xorg-xextproto=7.3.0
  - xorg-xproto=7.0.31
  - xz=5.2.6
  - yaml=0.2.5
  - zipp=3.20.2
  - zlib=1.3.1
  - zstandard=0.23.0
  - zstd=1.5.6
hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
hub/backbones.py
ADDED
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    LVD142M = 'LVD142M'


def _make_dinov2_model(
    *,
    arch_name: str = 'vit_large',
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = 'mlp',
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    from ..models.dinov2 import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f'Unsupported weights: {weights}')

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        model_full_name = _make_dinov2_model_name(
            arch_name, patch_size, num_register_tokens
        )
        url = (
            _DINOV2_BASE_URL
            + f'/{model_base_name}/{model_full_name}_pretrain.pth'
        )
        state_dict = torch.hub.load_state_dict_from_url(
            url, map_location='cpu'
        )
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M
    dataset."""
    return _make_dinov2_model(
        arch_name='vit_small', pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitb14(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M
    dataset."""
    return _make_dinov2_model(
        arch_name='vit_base', pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitl14(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M
    dataset."""
    return _make_dinov2_model(
        arch_name='vit_large', pretrained=pretrained, weights=weights, **kwargs
    )


def dinov2_vitg14(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M
    dataset."""
    return _make_dinov2_model(
        arch_name='vit_giant2',
        ffn_layer='swiglufused',
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-S/14 model with registers (optionally) pretrained on the
    LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name='vit_small',
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-B/14 model with registers (optionally) pretrained on the
    LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name='vit_base',
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-L/14 model with registers (optionally) pretrained on the
    LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name='vit_large',
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(
    *,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """DINOv2 ViT-g/14 model with registers (optionally) pretrained on the
    LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name='vit_giant2',
        ffn_layer='swiglufused',
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
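A small sketch of calling these factories directly; `pretrained=False` skips the weight download, and the printed values are what `vit_large` is expected to report:

```python
# Sketch: build an untrained ViT-L/14 backbone (no weight download).
from sinder.hub.backbones import dinov2_vitl14

model = dinov2_vitl14(pretrained=False)
print(model.patch_size, model.embed_dim)  # expected: 14 1024
```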
hub/utils.py
ADDED
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import itertools
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

_DINOV2_BASE_URL = 'https://dl.fbaipublicfiles.com/dinov2'


def _make_dinov2_model_name(
    arch_name: str, patch_size: int, num_register_tokens: int = 0
) -> str:
    compact_arch_name = arch_name.replace('_', '')[:4]
    registers_suffix = (
        f'_reg{num_register_tokens}' if num_register_tokens else ''
    )
    return f'dinov2_{compact_arch_name}{patch_size}{registers_suffix}'


class CenterPadding(nn.Module):
    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        new_size = math.ceil(size / self.multiple) * self.multiple
        pad_size = new_size - size
        pad_size_left = pad_size // 2
        pad_size_right = pad_size - pad_size_left
        return pad_size_left, pad_size_right

    @torch.inference_mode()
    def forward(self, x):
        pads = list(
            itertools.chain.from_iterable(
                self._get_pad(m) for m in x.shape[:1:-1]
            )
        )
        output = F.pad(x, pads)
        return output
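To illustrate `CenterPadding` and the name helper, a short sketch (the input sizes are arbitrary):

```python
import torch
from sinder.hub.utils import CenterPadding, _make_dinov2_model_name

pad = CenterPadding(multiple=14)
x = torch.randn(1, 3, 500, 640)
# Each spatial side is padded symmetrically up to the next multiple of 14.
print(pad(x).shape)  # torch.Size([1, 3, 504, 644])
print(_make_dinov2_model_name('vit_large', 14))  # dinov2_vitl14
```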
hubconf.py
ADDED
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.


from sinder.hub.backbones import (
    dinov2_vitb14,
    dinov2_vitb14_reg,
    dinov2_vitg14,
    dinov2_vitg14_reg,
    dinov2_vitl14,
    dinov2_vitl14_reg,
    dinov2_vits14,
    dinov2_vits14_reg,
)

dependencies = ['torch']

__all__ = [
    'dinov2_vitb14',
    'dinov2_vitb14_reg',
    'dinov2_vitg14',
    'dinov2_vitg14_reg',
    'dinov2_vitl14',
    'dinov2_vitl14_reg',
    'dinov2_vits14',
    'dinov2_vits14_reg',
]
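With this `hubconf.py` at the repository root, the entrypoints above can be loaded through `torch.hub`; a sketch with a placeholder checkout path:

```python
import torch

# Sketch: load an entrypoint from a local clone of this repository.
model = torch.hub.load('/path/to/sinder', 'dinov2_vits14', source='local')
```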
main.py
ADDED
@@ -0,0 +1,238 @@
#!/usr/bin/env python
import argparse
import os
import sys
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from sinder import (
    get_neighbor_loss,
    get_tokens,
    load_data,
    load_model,
    load_visual_data,
    pca_array,
    replace_back,
    replace_linear_addition_noqk,
)

os.environ['XFORMERS_DISABLED'] = '1'
torch.set_float32_matmul_precision('high')


def parse_args():
    parser = argparse.ArgumentParser(description='Beautify network')
    parser.add_argument(
        '--model', type=str, default='dinov2_vitg14', help='model name'
    )
    parser.add_argument('--work_dir', type=str, default='results')
    parser.add_argument('--resolution', type=int, default=518)
    parser.add_argument('--lr', type=float, default=0.005)
    parser.add_argument('--max_iter', type=int, default=30000)
    parser.add_argument('--num_train_max', type=int, default=30000)
    parser.add_argument('--mask_thr', type=float, default=4)
    parser.add_argument('--skip_less_than', type=int, default=3)
    parser.add_argument('--visual_size', type=int, default=448 * 2)
    parser.add_argument('--kernel', type=int, default=3)
    parser.add_argument('--save_at_skip', type=int, nargs='+', default=[75])
    parser.add_argument('--limit_layers', type=int, default=10)

    args = parser.parse_args()
    return args


def prepare_train(args, model):
    model.train()

    all_params = []
    for name, param in model.named_parameters():
        param.requires_grad = False

    # swap in the trainable repair layers; only their epsilon parameters
    # require gradients afterwards
    replace_linear_addition_noqk(model, 'model')
    for name, param in model.named_parameters():
        if '.epsilon' in name and param.requires_grad:
            all_params.append(param)

    grad_params = []
    for name, param in model.named_parameters():
        if param.requires_grad:
            grad_params.append(name)

    assert len(grad_params) == len(all_params)
    print(len(grad_params), grad_params)
    print(len(all_params), all_params)
    optimizer = torch.optim.SGD(
        all_params,
        lr=args.lr,
        momentum=0.9,
    )

    return optimizer


def save_model(args, model):
    print('save model')
    model.eval()

    replace_back(model, 'model')

    torch.save(model.state_dict(), args.folder / 'model.pt')


def train(args, model, dataset, optimizer, visual_dataset):
    print('training')
    skip_history = [False] * 1000
    model.train()

    for global_iter in tqdm(range(args.max_iter)):
        img = dataset[global_iter % len(dataset)]
        H = img.shape[1] // model.patch_size
        W = img.shape[2] // model.patch_size
        density = np.array(skip_history[-1000:]).astype(float).mean()
        print(f'{global_iter=} {W=} {H=} {density=:.2f}')

        for percent in args.save_at_skip:
            if percent / 100 <= density:
                print(f'save checkpoint at {density=}')
                args.save_at_skip.remove(percent)
                torch.save(model, args.folder / f'checkpoint_p{percent}.pth')
        if len(args.save_at_skip) == 0:
            break

        model.zero_grad()

        model.train()
        with torch.enable_grad():
            image_batch = img.unsqueeze(0).cuda()
            result = get_neighbor_loss(
                model,
                image_batch,
                skip_less_than=args.skip_less_than,
                mask_thr=args.mask_thr,
                kernel=args.kernel,
            )

        if result is None:
            skip_history.append(True)
            print('no loss, skip')
        else:
            skip_history.append(False)
            (
                layer,
                loss,
                I,
                J,
                T,
                alpha,
                mask_angle,
                x_token,
            ) = result

            print(
                f'{global_iter=}, {layer=}, {density=}, {alpha=:.2f}, {len(I)=}, '
                f'{loss.item()=:.2f}'
            )

            if torch.isnan(loss).any():
                print('nan loss, skip')
                continue
            loss.backward()

            # set some grad to 0
            if args.limit_layers:
                with torch.no_grad():
                    for t in range(layer - args.limit_layers + 1):
                        for p in model.blocks[t].parameters():
                            p.grad = None

            has_nan = False
            for name, param in model.named_parameters():
                if param.grad is not None and torch.isnan(param.grad).any():
                    print(f'nan grad at {name}, skip')
                    has_nan = True
            if has_nan:
                continue
            optimizer.step()

        # visualize
        if global_iter % 100 == 0:
            try:
                print(f'visualization at {global_iter=}')
                pca_img = pca_array(x_token)
                pca_img.save(args.folder / 'pca.png')
                mask_img = Image.fromarray(
                    (mask_angle * 255)
                    .detach()
                    .cpu()
                    .numpy()
                    .astype(np.uint8)
                ).resize((W * 7, H * 7), resample=Image.NEAREST)
                mask_img.save(args.folder / 'mask.png')
                Image.fromarray(
                    (
                        (
                            img.permute((1, 2, 0)).cpu().numpy() * 0.22
                            + 0.45
                        )
                        * 255
                    )
                    .clip(0, 255)
                    .astype(np.uint8)
                ).save(args.folder / 'img.png')
                if global_iter % 1000 == 0:
                    pca_img.save(args.folder / f'{global_iter:05}_pca.png')
                    mask_img.save(
                        args.folder / f'{global_iter:05}_mask.png'
                    )
                    Image.fromarray(
                        (
                            (
                                img.permute((1, 2, 0)).cpu().numpy() * 0.22
                                + 0.45
                            )
                            * 255
                        )
                        .clip(0, 255)
                        .astype(np.uint8)
                    ).save(args.folder / f'{global_iter:05}_img.png')
            except Exception as e:
                print(e)

            for d in range(len(visual_dataset)):
                visual_image = visual_dataset[d]
                visual_tokens_all = get_tokens(model, visual_image)
                visual_tokens, visual_tokens_cls = zip(*visual_tokens_all)
                pca_img = pca_array(visual_tokens[-1])
                pca_img.save(args.folder / f'{d}_pca.png')
                if global_iter % 500 == 0:
                    pca_img.save(
                        args.folder / f'{global_iter:05}_{d}_pca.png'
                    )

    torch.save(model, args.folder / 'checkpoint.pth')


def main():
    print('Start beautify')
    args = parse_args()

    name = f'res{args.resolution}_lr{args.lr}_{args.num_train_max}_skipless{args.skip_less_than}_maskthr{args.mask_thr}_limit{args.limit_layers}_ker{args.kernel}'
    args.folder = Path(args.work_dir) / name
    os.makedirs(args.folder, exist_ok=True)
    print(args)
    print(' '.join(sys.argv))
    print(f'work dir {args.folder}')
    model = load_model(args.model)
    dataset = load_data(args, model)
    visual_dataset = load_visual_data(args, model)
    optimizer = prepare_train(args, model)
    train(args, model, dataset, optimizer, visual_dataset)
    save_model(args, model)


if __name__ == '__main__':
    main()
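Note the two save formats above: `train()` checkpoints the whole module with `torch.save(model, ...)`, while `save_model()` writes a `state_dict` to `model.pt` after `replace_back()` restores the original layer types. A sketch of reading both back (the run folder name is a placeholder derived from the hyperparameters):

```python
import torch

# Whole-module checkpoint from train(); unpickling needs sinder importable.
model = torch.load('results/<run_name>/checkpoint.pth', map_location='cpu')

# Plain state_dict from save_model().
state = torch.load('results/<run_name>/model.pt', map_location='cpu')
print(len(state), 'tensors')
```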
models/dinov2/__init__.py
ADDED
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import logging

from . import vision_transformer as vits

logger = logging.getLogger('dinov2')


def build_model(args, only_teacher=False, img_size=224):
    args.arch = args.arch.removesuffix('_memeff')
    if 'vit' in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    return student, teacher, embed_dim


def build_model_from_cfg(cfg, only_teacher=False):
    return build_model(
        cfg.student,
        only_teacher=only_teacher,
        img_size=cfg.crops.global_crops_size,
    )
models/dinov2/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .attention import MemEffAttention
from .block import NestedTensorBlock
from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
models/dinov2/layers/attention.py
ADDED
@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor, nn

logger = logging.getLogger('dinov2')


XFORMERS_ENABLED = os.environ.get('XFORMERS_DISABLED') is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
        warnings.warn('xFormers is available (Attention)')
    else:
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor) -> Tensor:
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )

        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
        attn = q @ k.transpose(-2, -1)

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError(
                    'xFormers is required for using nested tensors'
                )
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
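A tiny shape check for the plain `Attention` module; the sizes are arbitrary, the memory-efficient path only activates when xFormers is installed, and the `sinder.models...` import path assumes the installed package layout implied by `hub/backbones.py` above:

```python
import torch
from sinder.models.dinov2.layers.attention import Attention

attn = Attention(dim=64, num_heads=8, qkv_bias=True)
x = torch.randn(2, 10, 64)   # (batch, tokens, dim)
print(attn(x).shape)         # torch.Size([2, 10, 64]); shape is preserved
```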
models/dinov2/layers/block.py
ADDED
@@ -0,0 +1,304 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
import warnings
from typing import Any, Callable, Dict, List, Tuple

import torch
from torch import Tensor, nn

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp

logger = logging.getLogger('dinov2')


XFORMERS_ENABLED = os.environ.get('XFORMERS_DISABLED') is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, index_select_cat, scaled_index_add

        XFORMERS_AVAILABLE = True
        warnings.warn('xFormers is available (Block)')
    else:
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = (
            LayerScale(dim, init_values=init_values)
            if init_values
            else nn.Identity()
        )
        self.drop_path1 = (
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values)
            if init_values
            else nn.Identity()
        )
        self.drop_path2 = (
            DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        )

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(
        x_flat,
        0,
        brange,
        residual.to(dtype=x.dtype),
        alpha=residual_scale_factor,
    )
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(
    x, brange, residual, residual_scale_factor, scaling_vector=None
):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(
            x_flat,
            0,
            brange,
            residual.to(dtype=x.dtype),
            alpha=residual_scale_factor,
        )
    else:
        x_plus_residual = scaled_index_add(
            x,
            brange,
            residual.to(dtype=x.dtype),
            scaling=scaling_vector,
            alpha=residual_scale_factor,
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """This will perform the index select, cat the tensors, and provide the
    attn_bias from cache."""
    batch_sizes = (
        [b.shape[0] for b in branges]
        if branges is not None
        else [x.shape[0] for x in x_list]
    )
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat(
            [x.flatten(1) for x in x_list], branges
        ).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [
        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio)
        for x in x_list
    ]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(
        x_list, branges, residual_list, residual_scale_factors
    ):
        outputs.append(
            add_residual(
                x, brange, residual, residual_scale_factor, scaling_vector
            ).view_as(x)
        )
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """x_list contains a list of tensors to nest together and run."""
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma
                if isinstance(self.ls1, LayerScale)
                else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma
                if isinstance(self.ls1, LayerScale)
                else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError(
                    'xFormers is required for using nested tensors'
                )
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
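A corresponding sketch for `Block` (arbitrary sizes; `init_values` enables the `LayerScale` branches, and `.eval()` forces the plain residual path rather than stochastic depth):

```python
import torch
from sinder.models.dinov2.layers.block import Block

blk = Block(dim=64, num_heads=4, init_values=1e-5).eval()
x = torch.randn(2, 197, 64)   # (batch, tokens, dim)
print(blk(x).shape)           # torch.Size([2, 197, 64])
```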
models/dinov2/layers/dino_head.py
ADDED
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(
            nlayers,
            in_dim,
            bottleneck_dim,
            hidden_dim=hidden_dim,
            use_bn=use_bn,
            bias=mlp_bias,
        )
        self.apply(self._init_weights)
        self.last_layer = weight_norm(
            nn.Linear(bottleneck_dim, out_dim, bias=False)
        )
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(
    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
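A shape sketch for `DINOHead`; `out_dim` here is just an illustrative prototype count, not a value prescribed by this repository:

```python
import torch
from sinder.models.dinov2.layers.dino_head import DINOHead

head = DINOHead(in_dim=384, out_dim=65536)
# 384 -> 2048 -> 2048 -> 256 MLP, L2-normalize, weight-normed 256 -> 65536.
print(head(torch.randn(4, 384)).shape)  # torch.Size([4, 65536])
```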
models/dinov2/layers/drop_path.py
ADDED
@@ -0,0 +1,37 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
```
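`drop_path` zeroes the whole residual branch for a random subset of samples and rescales the survivors by `1 / keep_prob`, so the expectation matches eval mode, where it is the identity. A minimal sketch (import path assumed):

```python
import torch

from models.dinov2.layers.drop_path import DropPath

torch.manual_seed(0)
dp = DropPath(drop_prob=0.25)
x = torch.ones(4, 3, 8)  # 4 samples in the batch

dp.train()
y = dp(x)
print(y[:, 0, 0])  # per sample: 0.0 (dropped) or 1/0.75 (kept, rescaled)

dp.eval()
assert torch.equal(dp(x), x)  # identity at inference
```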
models/dinov2/layers/layer_scale.py
ADDED
@@ -0,0 +1,26 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor, nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
```
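`LayerScale` multiplies each channel by its own learnable `gamma`, initialized to a small constant so a residual branch starts out nearly muted. A sketch with illustrative sizes:

```python
import torch

from models.dinov2.layers.layer_scale import LayerScale

ls = LayerScale(dim=4, init_values=1e-5)
x = torch.randn(2, 3, 4)  # (batch, tokens, channels)
# At init every channel is scaled by 1e-5, so the branch barely
# perturbs the residual stream until gamma is trained.
assert torch.allclose(ls(x), x * 1e-5)
```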
models/dinov2/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
```
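Shape sketch for `Mlp`; in the transformer it is built with `hidden_features = embed_dim * mlp_ratio`, and the sizes here are only illustrative:

```python
import torch

from models.dinov2.layers.mlp import Mlp

mlp = Mlp(in_features=384, hidden_features=4 * 384)  # mlp_ratio = 4
tokens = torch.randn(2, 197, 384)  # (batch, tokens, channels)
print(mlp(tokens).shape)  # torch.Size([2, 197, 384])
```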
models/dinov2/layers/patch_embed.py
ADDED
@@ -0,0 +1,100 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

import torch.nn as nn
from torch import Tensor


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
        )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert (
            H % patch_H == 0
        ), f'Input image height {H} is not a multiple of patch height {patch_H}'
        assert (
            W % patch_W == 0
        ), f'Input image width {W} is not a multiple of patch width: {patch_W}'

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = (
            Ho
            * Wo
            * self.embed_dim
            * self.in_chans
            * (self.patch_size[0] * self.patch_size[1])
        )
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
```
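Worked example of the tokenization arithmetic: with the 14-pixel patches of DINOv2 and the 518-pixel default input used by this repo's visualization, 518 / 14 = 37 patches per side, i.e. 1369 patch tokens. A sketch (embedding width 1536 as in ViT-g; import path assumed):

```python
import torch

from models.dinov2.layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1536)
img = torch.randn(1, 3, 518, 518)
tokens = pe(img)
print(pe.num_patches)  # 1369 (= 37 * 37)
print(tokens.shape)    # torch.Size([1, 1369, 1536])
```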
models/dinov2/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,69 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
import warnings
from typing import Callable, Optional

import torch.nn.functional as F
from torch import Tensor, nn


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get('XFORMERS_DISABLED') is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        warnings.warn('xFormers is available (SwiGLU)')
    else:
        raise ImportError
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
```
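`SwiGLUFFNFused` shrinks the requested hidden width by 2/3 (SwiGLU spends two projections on the way in, so this keeps the parameter count close to a plain GELU MLP) and rounds up to a multiple of 8 for kernel efficiency. The rounding, isolated as a worked example:

```python
def fused_hidden(hidden_features: int) -> int:
    # Same arithmetic as SwiGLUFFNFused: shrink by 2/3, pad to a multiple of 8.
    return (int(hidden_features * 2 / 3) + 7) // 8 * 8

print(fused_hidden(4096))  # 2736: int(4096 * 2/3) = 2730, padded up to 2736
```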
models/dinov2/vision_transformer.py
ADDED
@@ -0,0 +1,454 @@
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import math
from functools import partial
from typing import Callable, Sequence, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from .layers import MemEffAttention, Mlp
from .layers import NestedTensorBlock as Block
from .layers import PatchEmbed, SwiGLUFFNFused

logger = logging.getLogger('dinov2')


def named_apply(
    fn: Callable,
    module: nn.Module,
    name='',
    depth_first=True,
    include_root=False,
) -> nn.Module:
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = '.'.join((name, child_name)) if name else child_name
        named_apply(
            fn=fn,
            module=child_module,
            name=child_name,
            depth_first=depth_first,
            include_root=True,
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer='mlp',
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks (int): split block sequence into block_chunks units for FSDP wrap
            num_register_tokens (int): number of extra cls tokens (so-called "registers")
            interpolate_antialias (bool): flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset (float): work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = (
            self.embed_dim
        ) = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
        )
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule

        if ffn_layer == 'mlp':
            logger.info('using MLP layer as FFN')
            ffn_layer = Mlp
        elif ffn_layer == 'swiglufused' or ffn_layer == 'swiglu':
            logger.info('using SwiGLU layer as FFN')
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == 'identity':
            logger.info('using Identity layer as FFN')

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append(
                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
                )
            self.blocks = nn.ModuleList(
                [BlockChunk(p) for p in chunked_blocks]
            )
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(
            math.sqrt(N)
        )  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs['scale_factor'] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs['size'] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode='bicubic',
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat(
            (class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1
        ).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(
                masks.unsqueeze(-1),
                self.mask_token.to(x.dtype).unsqueeze(0),
                x,
            )

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        x = [
            self.prepare_tokens_with_masks(x, masks)
            for x, masks in zip(x_list, masks_list)
        ]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    'x_norm_clstoken': x_norm[:, 0],
                    'x_norm_regtokens': x_norm[
                        :, 1 : self.num_register_tokens + 1
                    ],
                    'x_norm_patchtokens': x_norm[
                        :, self.num_register_tokens + 1 :
                    ],
                    'x_prenorm': x,
                    'masks': masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            'x_norm_clstoken': x_norm[:, 0],
            'x_norm_regtokens': x_norm[:, 1 : self.num_register_tokens + 1],
            'x_norm_patchtokens': x_norm[:, self.num_register_tokens + 1 :],
            'x_prenorm': x,
            'masks': masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = (
            range(total_block_len - n, total_block_len)
            if isinstance(n, int)
            else n
        )
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(
            blocks_to_take
        ), f'only {len(output)} / {len(blocks_to_take)} blocks found'
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = (
            range(total_block_len - n, total_block_len)
            if isinstance(n, int)
            else n
        )
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(
            blocks_to_take
        ), f'only {len(output)} / {len(blocks_to_take)} blocks found'
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
                .permute(0, 3, 1, 2)
                .contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret['x_norm_clstoken'])


def init_weights_vit_timm(module: nn.Module, name: str = ''):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per
    head 64."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
```
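A minimal end-to-end sketch with a randomly initialized `vit_small`; the pretrained and repaired models are loaded through `torch.hub` in `utils.py` below, so this only demonstrates the tensor interfaces (import path assumed):

```python
import torch

from models.dinov2.vision_transformer import vit_small

model = vit_small(patch_size=14, img_size=518).eval()  # random weights

img = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    out = model.forward_features(img)

print(out['x_norm_clstoken'].shape)     # torch.Size([1, 384])
print(out['x_norm_patchtokens'].shape)  # torch.Size([1, 1369, 384])
```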
neighbor_loss.py
ADDED
@@ -0,0 +1,137 @@
```python
import torch
from torch.nn.functional import normalize


def check_anomaly_theoretical(
    x,
    H,
    W,
    anomaly_dir=None,
    temperature=0.1,
    mask_thr=0.001,
    kernel=3,
):
    x_token = x[:, 1:]
    B = x.shape[0]
    assert B == 1
    x_token = x_token.reshape(H, W, -1).contiguous()

    with torch.no_grad():
        feature = normalize(x_token, dim=-1)
        direction = normalize(anomaly_dir, dim=-1)

        logits = -(feature * direction).sum(dim=-1).abs()
        prob = torch.exp(logits / temperature)

        assert kernel in (3, 5)
        pad = kernel // 2

        w = prob.unfold(0, kernel, 1).unfold(1, kernel, 1)
        w = w / w.sum(dim=(-1, -2), keepdims=True)

        if kernel == 3:
            gaussian = (
                torch.FloatTensor(
                    [
                        1 / 16,
                        1 / 8,
                        1 / 16,
                        1 / 8,
                        1 / 4,
                        1 / 8,
                        1 / 16,
                        1 / 8,
                        1 / 16,
                    ]
                )
                .to(w.device)
                .reshape(1, 1, 3, 3)
            )
        elif kernel == 5:
            gaussian = (
                torch.tensor(
                    [
                        [1, 4, 7, 4, 1],
                        [4, 16, 26, 16, 4],
                        [7, 26, 41, 26, 7],
                        [4, 16, 26, 16, 4],
                        [1, 4, 7, 4, 1],
                    ]
                )
                .float()
                .to(w.device)
                / 273
            )

        w2 = w * gaussian

        w2 = w2 / w2.sum(dim=(-1, -2), keepdims=True)

        T = x_token.unfold(0, kernel, 1).unfold(1, kernel, 1)
        T = (T * w2[:, :, None].to(T.device)).sum(dim=(-1, -2))

        mask_full = logits < logits.mean() - mask_thr * logits.std()
        mask_full[:pad, :] = False
        mask_full[:, :pad] = False
        mask_full[-pad:, :] = False
        mask_full[:, -pad:] = False
    index_tensor = torch.nonzero(mask_full.flatten()).flatten()
    if len(index_tensor) == 0:
        return None
    rows = index_tensor // W
    cols = index_tensor % W

    alpha = x_token[pad:-pad, pad:-pad].norm(dim=-1).mean()

    loss_neighbor = (
        (x_token[rows, cols] - T[rows - pad, cols - pad]).norm(dim=-1)
    ).mean() / alpha

    return loss_neighbor, rows, cols, T, alpha, mask_full, x_token


def get_neighbor_loss(
    model,
    x,
    skip_less_than=1,
    mask_thr=0.001,
    kernel=3,
):
    H = x.shape[2]
    W = x.shape[3]
    x = model.prepare_tokens_with_masks(x)

    for i, blk in enumerate(model.blocks):
        x = blk(x)
        assert len(model.singular_defects) > 0
        result = check_anomaly_theoretical(
            x,
            H // model.patch_size,
            W // model.patch_size,
            model.singular_defects[i],
            mask_thr=mask_thr,
            kernel=kernel,
        )
        if result is not None:
            (
                loss_neighbor,
                rows,
                cols,
                T,
                alpha,
                mask_angle,
                x_token,
            ) = result
            if len(rows) >= skip_less_than:
                assert not torch.isnan(loss_neighbor).any()
                return (
                    i,
                    loss_neighbor,
                    rows,
                    cols,
                    T,
                    alpha,
                    mask_angle,
                    x_token,
                )
    return None
```
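`check_anomaly_theoretical` flags patch tokens that align with a singular defect direction and measures how far each flagged token sits from a probability- and Gaussian-weighted average of its neighbors, normalized by the mean token norm. A sketch on synthetic tokens (random stand-ins for real features and for a defect direction; the repo root is assumed to be on the import path):

```python
import torch

from neighbor_loss import check_anomaly_theoretical

torch.manual_seed(0)
H, W, C = 16, 16, 64
x = torch.randn(1, 1 + H * W, C)  # one CLS token plus an H x W patch grid
direction = torch.randn(C)        # stand-in for a singular defect direction

result = check_anomaly_theoretical(x, H, W, anomaly_dir=direction)
if result is not None:
    loss, rows, cols, T, alpha, mask, x_token = result
    print(len(rows), 'flagged patches, loss =', round(loss.item(), 4))
```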
pyproject.toml
ADDED
@@ -0,0 +1,17 @@
```toml
[build-system]
requires = ["setuptools>=68", "setuptools_scm[toml]>=8"]
build-backend = "setuptools.build_meta"

[project]
name = "sinder"
requires-python = ">=3.8"
dynamic = ["version"]
dependencies = [
    # Add runtime dependencies here
]

# Enables the usage of setuptools_scm
[tool.setuptools_scm]

[tool.setuptools]
py-modules = []
```
repair.py
ADDED
@@ -0,0 +1,147 @@
```python
import torch


class SVDLinearAddition(torch.nn.Module):
    def __init__(self, linear, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        w, b = linear.weight, linear.bias

        self.bias = b

        with torch.no_grad():
            U, S, Vt = torch.linalg.svd(w, full_matrices=False)
            self.U = torch.nn.Parameter(U)
            self.S = torch.nn.Parameter(S)
            self.Vt = torch.nn.Parameter(Vt)
            self.epsilon = torch.nn.Parameter(torch.zeros_like(S))

    def forward(self, x):
        w = self.U @ torch.diag(self.S + self.epsilon) @ self.Vt
        x = torch.nn.functional.linear(x, w, self.bias)
        return x


class SVDLinearAdditionQKV(torch.nn.Module):
    def __init__(self, linear, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        w, b = linear.weight, linear.bias

        self.bias = b

        with torch.no_grad():
            c = w.shape[0] // 3
            q = w[:c]
            k = w[c : 2 * c]
            v = w[2 * c :]
            U_q, S_q, Vt_q = torch.linalg.svd(q, full_matrices=False)
            self.U_q = torch.nn.Parameter(U_q)
            self.S_q = torch.nn.Parameter(S_q)
            self.Vt_q = torch.nn.Parameter(Vt_q)
            self.epsilon_q = torch.nn.Parameter(torch.zeros_like(S_q))
            U_k, S_k, Vt_k = torch.linalg.svd(k, full_matrices=False)
            self.U_k = torch.nn.Parameter(U_k)
            self.S_k = torch.nn.Parameter(S_k)
            self.Vt_k = torch.nn.Parameter(Vt_k)
            self.epsilon_k = torch.nn.Parameter(torch.zeros_like(S_k))
            U_v, S_v, Vt_v = torch.linalg.svd(v, full_matrices=False)
            self.U_v = torch.nn.Parameter(U_v)
            self.S_v = torch.nn.Parameter(S_v)
            self.Vt_v = torch.nn.Parameter(Vt_v)
            self.epsilon_v = torch.nn.Parameter(torch.zeros_like(S_v))

    def forward(self, x):
        w = torch.concatenate(
            (
                self.U_q @ torch.diag(self.S_q + self.epsilon_q) @ self.Vt_q,
                self.U_k @ torch.diag(self.S_k + self.epsilon_k) @ self.Vt_k,
                self.U_v @ torch.diag(self.S_v + self.epsilon_v) @ self.Vt_v,
            )
        )
        x = torch.nn.functional.linear(x, w, self.bias)
        return x


def replace_linear_addition_noqk(module, name):
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.Linear:
            if 'qkv' in attr_str:
                print('replaced: ', name, attr_str)
                svd_linear_qkv = SVDLinearAdditionQKV(target_attr)
                svd_linear_qkv.U_q.requires_grad = False
                svd_linear_qkv.U_k.requires_grad = False
                svd_linear_qkv.U_v.requires_grad = False
                svd_linear_qkv.S_q.requires_grad = False
                svd_linear_qkv.S_k.requires_grad = False
                svd_linear_qkv.S_v.requires_grad = False
                svd_linear_qkv.Vt_q.requires_grad = False
                svd_linear_qkv.Vt_k.requires_grad = False
                svd_linear_qkv.Vt_v.requires_grad = False
                svd_linear_qkv.epsilon_q.requires_grad = False
                svd_linear_qkv.epsilon_k.requires_grad = False
                svd_linear_qkv.epsilon_v.requires_grad = True
                svd_linear_qkv.bias.requires_grad = False
                setattr(module, attr_str, svd_linear_qkv)
            else:
                print('replaced: ', name, attr_str)
                svd_linear = SVDLinearAddition(target_attr)
                svd_linear.U.requires_grad = False
                svd_linear.S.requires_grad = False
                svd_linear.Vt.requires_grad = False
                svd_linear.bias.requires_grad = False
                svd_linear.epsilon.requires_grad = True
                setattr(module, attr_str, svd_linear)

    for name, immediate_child_module in module.named_children():
        replace_linear_addition_noqk(immediate_child_module, name)


def replace_back(module, name):
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)

        if type(target_attr) == SVDLinearAddition:
            print('replaced back: ', name, attr_str)
            with torch.no_grad():
                linear = torch.nn.Linear(
                    target_attr.Vt.shape[1],
                    target_attr.U.shape[0],
                    device=target_attr.U.device,
                )
                linear.weight.add_(
                    target_attr.U
                    @ torch.diag(target_attr.S + target_attr.epsilon)
                    @ target_attr.Vt
                    - linear.weight
                )
                linear.bias.add_(target_attr.bias - linear.bias)

            setattr(module, attr_str, linear)

        elif type(target_attr) == SVDLinearAdditionQKV:
            print('replaced back: ', name, attr_str)
            with torch.no_grad():
                w = torch.concatenate(
                    (
                        target_attr.U_q
                        @ torch.diag(target_attr.S_q + target_attr.epsilon_q)
                        @ target_attr.Vt_q,
                        target_attr.U_k
                        @ torch.diag(target_attr.S_k + target_attr.epsilon_k)
                        @ target_attr.Vt_k,
                        target_attr.U_v
                        @ torch.diag(target_attr.S_v + target_attr.epsilon_v)
                        @ target_attr.Vt_v,
                    )
                )
                linear = torch.nn.Linear(
                    w.shape[1], w.shape[0], device=target_attr.U_q.device
                )
                linear.weight.add_(w - linear.weight)
                linear.bias.add_(target_attr.bias - linear.bias)

            setattr(module, attr_str, linear)

    # iterate through immediate child modules. Note, the recursion is done by our code no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_back(immediate_child_module, name)
```
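The intended round trip: wrap every `nn.Linear` so its weight is expressed as `U @ diag(S + epsilon) @ Vt` with only the singular-value offsets `epsilon` trainable (and, for fused qkv weights, only the value part `epsilon_v`), train, then fold the result back into plain `Linear` layers. A toy sketch of that round trip (hypothetical module; the repo root is assumed to be on the import path):

```python
import torch

from repair import replace_back, replace_linear_addition_noqk


class Toy(torch.nn.Module):  # hypothetical stand-in for a ViT block
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(8, 8)
        self.qkv = torch.nn.Linear(8, 24)  # name triggers the qkv branch


toy = Toy()
w_before = toy.fc.weight.detach().clone()

replace_linear_addition_noqk(toy, 'toy')
print([n for n, p in toy.named_parameters() if p.requires_grad])
# ['fc.epsilon', 'qkv.epsilon_v']

# ... optimize those epsilons against the neighbor loss here ...

replace_back(toy, 'toy')  # fold U @ diag(S + epsilon) @ Vt back into Linear
assert torch.allclose(toy.fc.weight, w_before, atol=1e-5)  # epsilon was 0
```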
setup.cfg
ADDED
@@ -0,0 +1,23 @@
```ini
[bdist_wheel]
universal=1

[yapf]
based_on_style = pep8
blank_line_before_nested_class_or_def = true
split_before_expression_after_opening_paren = true
allow_split_before_dict_value = False
split_penalty_import_names=0
SPLIT_PENALTY_AFTER_OPENING_BRACKET=0

[isort]
line_length = 79
extra_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
known_first_party =
known_third_party = addict,cv2,matplotlib,numpy,onnx,packaging,pytest,pytorch_sphinx_theme,scipy,sphinx,torch,torchvision,yaml,yapf
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
profile = black

[flake8]
extend-ignore = E501,E203,E722,E266,E402,E251
exclude = .git,__pycache__
```
singular_defect.py
ADDED
@@ -0,0 +1,149 @@
```python
# compute singular defect directions
import torch
import torch.nn.functional as F


def anomaly_dir_attn(
    blk,
    identity=False,
    bias=False,
    centered=False,
    homogeneous=False,
):
    with torch.no_grad():
        N = blk.ls1.gamma.shape[0]
        dev = blk.ls1.gamma.device

        A4 = torch.diag(blk.ls1.gamma)
        A3 = blk.attn.proj.weight
        B3 = blk.attn.proj.bias
        A2 = blk.attn.qkv.weight.chunk(3, dim=0)[-1]
        B2 = blk.attn.qkv.bias.chunk(3, dim=0)[-1]
        A1 = torch.diag(blk.norm1.weight)
        B1 = blk.norm1.bias
        A0 = (torch.eye(N) - 1 / N * torch.ones(N, N)).to(dev)
        A = A4 @ A3 @ A2 @ A1

        if centered:
            A = A @ A0
        B = A4 @ (A3 @ (A2 @ B1)) + A4 @ (A3 @ B2) + A4 @ B3

        if bias:
            A = torch.cat((A, B[:, None]), dim=1)
            if homogeneous:
                onehot = torch.cat(
                    (torch.zeros_like(B), torch.ones(1).to(dev))
                )
                A = torch.cat((A, onehot[None]), dim=0)

        if identity:
            iden = torch.eye(N).to(dev)
            A[:N, :N] += iden
        u, _, _ = torch.linalg.svd(A)

    return u[:N, 0], A, B


def w12(blk, x):
    with torch.no_grad():
        x1, x2 = blk.mlp.w12(x).chunk(2, dim=-1)
        return F.silu(x1) * x2


def anomaly_dir_mlp_ls(
    blk,
    identity=False,
    bias=False,
    centered=False,
    homogeneous=False,
    bias_ls=False,
):
    with torch.no_grad():
        N = blk.ls2.gamma.shape[0]
        M = blk.mlp.w3.weight.shape[1]
        dev = blk.ls2.gamma.device

        A4 = torch.diag(blk.ls2.gamma)
        A3 = blk.mlp.w3.weight
        B3 = blk.mlp.w3.bias

        X = torch.randn(100000, N, device=dev)
        Y = w12(blk, X)
        if bias_ls:
            X_one = torch.cat((X, torch.ones(100000, 1).to(dev)), dim=1)
        else:
            X_one = X
        sol = torch.linalg.lstsq(X_one, Y)
        if bias_ls:
            A2 = sol.solution.T[:, :-1]
            B2 = sol.solution.T[:, -1]
        else:
            A2 = sol.solution.T
            B2 = torch.zeros(M).to(dev)

        A1 = torch.diag(blk.norm2.weight)
        B1 = blk.norm2.bias
        A0 = (torch.eye(N) - 1 / N * torch.ones(N, N)).to(dev)
        A = A4 @ A3 @ A2 @ A1

        if centered:
            A = A @ A0
        B = A4 @ (A3 @ (A2 @ B1)) + A4 @ (A3 @ B2) + A4 @ B3

        if bias:
            A = torch.cat((A, B[:, None]), dim=1)
            if homogeneous:
                onehot = torch.cat(
                    (torch.zeros_like(B), torch.ones(1).to(dev))
                )
                A = torch.cat((A, onehot[None]), dim=0)

        if identity:
            iden = torch.eye(N).to(dev)
            A[:N, :N] += iden
        u, s, vt = torch.linalg.svd(A)

    return u[:N, 0], A, B


def anomaly_dir(blk, homogeneous=False):
    _, A, b = anomaly_dir_attn(
        blk,
        identity=True,
        bias=homogeneous,
        centered=True,
        homogeneous=homogeneous,
    )
    _, C, d = anomaly_dir_mlp_ls(
        blk,
        identity=True,
        bias=homogeneous,
        bias_ls=False,
        centered=True,
        homogeneous=homogeneous,
    )

    with torch.no_grad():
        N = b.shape[0]
        AA = C @ A
        if homogeneous:
            BB = 0
        else:
            BB = C @ b + d
        u, _, _ = torch.linalg.svd(AA)

    return u[:N, 0], AA, BB


def singular_defect_directions(model):
    accumulative_anomalies = []
    anomaly_dab = [anomaly_dir(blk) for blk in model.blocks]
    anomaly_as = [dab[1] for dab in anomaly_dab]

    with torch.no_grad():
        aaa = torch.eye(anomaly_as[0].shape[0]).to(anomaly_as[0])
        for a in anomaly_as:
            aaa = a @ aaa
            u, _, _ = torch.linalg.svd(aaa)
            accumulative_anomalies.append(u[:, 0])
    return accumulative_anomalies
```
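Each block is linearized as `x -> A x + b` (the attention path in closed form, the SwiGLU MLP by least squares on random inputs), and the top left-singular vector of the accumulated product of the `A` matrices approximates the direction along which defect tokens grow. A sketch on the original backbone (pulls the upstream DINOv2 hub model; a CUDA device is recommended for the least-squares fits):

```python
import torch

from singular_defect import singular_defect_directions

model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14').cuda()

dirs = singular_defect_directions(model)
print(len(dirs))       # one accumulated direction per block (40 for ViT-g)
print(dirs[0].shape)   # torch.Size([1536]): the embedding dimension
print(dirs[0].norm())  # ~1.0, left-singular vectors are unit norm
```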
utils.py
ADDED
@@ -0,0 +1,70 @@
```python
from pathlib import Path

import torch
from PIL import Image
from sklearn.decomposition import PCA

import sinder
from .singular_defect import singular_defect_directions


def pca_array(tokens, whiten=False):
    h, w, c = tokens.shape
    tokens = tokens.detach().cpu()

    pca = PCA(n_components=3, whiten=whiten)
    pca.fit(tokens.reshape(-1, c))
    projected_tokens = pca.transform(tokens.reshape(-1, c))

    t = torch.tensor(projected_tokens)
    t_min = t.min(dim=0, keepdim=True).values
    t_max = t.max(dim=0, keepdim=True).values
    normalized_t = (t - t_min) / (t_max - t_min)

    array = (normalized_t * 255).byte().numpy()
    array = array.reshape(h, w, 3)

    return Image.fromarray(array).resize((w * 7, h * 7), 0)


def get_tokens(model, image, blocks=1):
    model.eval()
    with torch.no_grad():
        image_batch = image.unsqueeze(0).cuda()
        H = image_batch.shape[2]
        W = image_batch.shape[3]
        print(f'{W=} {H=}')
        tokens = model.get_intermediate_layers(
            image_batch, blocks, return_class_token=True, norm=False
        )
        tokens = [
            (
                t.reshape(
                    (H // model.patch_size, W // model.patch_size, t.size(-1))
                ),
                tc,
            )
            for t, tc in tokens
        ]

    return tokens


def load_model(model_name, checkpoint=None):
    print(f'using {model_name} model')
    model = torch.hub.load(
        repo_or_dir=Path(sinder.__file__).parent.parent,
        source='local',
        model=model_name,
    )
    if checkpoint is not None:
        states = torch.load(checkpoint, map_location='cpu')
        model.load_state_dict(states, strict=False)
    model = model.cuda()
    model.eval()
    model.interpolate_antialias = True
    model.singular_defects = singular_defect_directions(model)
    print(f'model loaded. patch size: {model.patch_size}')

    return model
```
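`pca_array` projects C-dimensional patch tokens onto their top 3 principal components and renders them as an RGB image, with each token drawn as a 7x7 pixel block. A sketch on random tokens (assumes the repo is importable as the `sinder` package):

```python
import torch

from sinder import pca_array

tokens = torch.randn(37, 37, 1536)  # stand-in for a (H/14, W/14, C) grid
img = pca_array(tokens)             # PIL image, 3-component PCA as RGB
print(img.size)                     # (259, 259): 37 tokens * 7 px per side
```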
visualize.py
ADDED
@@ -0,0 +1,82 @@
```python
#!/usr/bin/env python
import argparse
import os
import sys
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm

from sinder import (
    get_tokens,
    load_model,
    load_visual_data,
    pca_array,
)

os.environ['XFORMERS_DISABLED'] = '1'


def parse_args():
    parser = argparse.ArgumentParser(description='Visualize')
    parser.add_argument(
        'imgs', nargs='+', type=str, help='path to image/images'
    )
    parser.add_argument(
        '--model', type=str, default='dinov2_vitg14', help='model name'
    )
    parser.add_argument('--workdir', type=str, default='visualize')
    parser.add_argument(
        '--checkpoint',
        type=str,
        default=None,
        help='Path to checkpoint. Default is None, which loads the official pretrained weights',
    )
    parser.add_argument(
        '--visual_size',
        type=int,
        default=518,
        help='short side size of input image',
    )

    args = parser.parse_args()
    return args


def visualize(args, model, visual_dataset):
    model.eval()

    for d in tqdm(range(len(visual_dataset))):
        visual_image = visual_dataset[d]
        visual_tokens_all = get_tokens(model, visual_image)
        visual_tokens, visual_tokens_cls = zip(*visual_tokens_all)
        filename = Path(visual_dataset.files[d]).stem

        t = visual_tokens[-1].detach().cpu()
        h, w, c = t.shape
        norm = ((t.norm(dim=-1) / t.norm(dim=-1).max()) * 255).byte().numpy()
        norm_img = Image.fromarray(norm).resize((w * 14, h * 14), 0)
        norm = cv2.applyColorMap(np.array(norm_img), cv2.COLORMAP_JET)
        # cv2.imwrite expects a string path, not a pathlib.Path
        cv2.imwrite(str(args.folder / f'{filename}_norm.png'), norm)

        pca_img = pca_array(visual_tokens[-1])
        pca_img.save(args.folder / f'{filename}_pca.png')


def main():
    args = parse_args()

    args.folder = Path(args.workdir).expanduser()
    os.makedirs(args.folder, exist_ok=True)
    print(args)
    print(' '.join(sys.argv))

    model = load_model(args.model, args.checkpoint)
    visual_dataset = load_visual_data(args, model)
    visualize(args, model, visual_dataset)


if __name__ == '__main__':
    main()
```