haoqiwang committed on
Commit
9ae1b1e
1 Parent(s): be446cc
README.md CHANGED
@@ -1,3 +1,47 @@
- ---
- license: apache-2.0
- ---
+ # Official code for SINDER: Repairing the Singular Defects of DINOv2 (ECCV 2024 Oral)
+
+ [![🦢 - Paper](https://img.shields.io/badge/🦢-Paper-red)](https://arxiv.org/abs/2407.16826)
+
+ ![SINDER](./resources/high_norm.jpg)
+
+ ## Citation
+
+ ```bibtex
+ @InProceedings{Haoqi_2024_ECCV,
+     author    = {Wang, Haoqi and Zhang, Tong and Salzmann, Mathieu},
+     title     = {{SINDER}: Repairing the Singular Defects of DINOv2},
+     booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
+     month     = {September},
+     year      = {2024}
+ }
+ ```
+
+ ## Install
+
+ ```bash
+ conda env create -f environment.yml
+ conda activate sinder
+ pip install -e .
+ ```
+
+ ## Train
+
+ Place the ImageNet-1K dataset in the `data/imagenet` folder.
+
+ ```bash
+ ./main.py
+ ```
+
+ ## Visualize
+
+ ```bash
+ # visualize the original DINOv2
+ ./visualize.py resources/example.jpg
+
+ # visualize SINDER (checkpoint download link below)
+ ./visualize.py resources/example.jpg --checkpoint path/to/sinder.pth
+ ```
+
+ ## Checkpoints
+
+ [Google Drive](https://drive.google.com/file/d/1g0Aq5qXYuMmVrN9-gGwC9ybwlCDFAw-l/view?usp=sharing)
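
Beyond the CLI commands above, the downloaded checkpoint can also be used from Python. The following is a minimal, untested sketch, assuming the released `sinder.pth` was written with `torch.save(model, ...)` as in `main.py`, that the `sinder` package is installed via `pip install -e .` so the pickled classes resolve, and that `load_visual_data`, `get_tokens`, and `pca_array` behave as they are used in `main.py`; the paths and `visual_size` below are placeholders.

```python
# Minimal sketch; see the assumptions stated above.
from types import SimpleNamespace

import torch

from sinder import get_tokens, load_visual_data, pca_array

# main.py saves checkpoints with torch.save(model, ...), so loading one
# returns the repaired DINOv2 module directly (the sinder package must be
# importable for unpickling).
model = torch.load('path/to/sinder.pth', map_location='cuda')
model.eval()

# Build the visualization dataset as main.py does, but for a single image.
args = SimpleNamespace(visual_size=896, imgs=['resources/example.jpg'])
dataset = load_visual_data(args, model)

# get_tokens returns per-layer (patch tokens, cls token) pairs; render a PCA
# image of the last layer's patch tokens, mirroring the loop in main.py.
tokens, _ = zip(*get_tokens(model, dataset[0]))
pca_array(tokens[-1]).save('sinder_pca.png')
```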
__init__.py ADDED
@@ -0,0 +1,17 @@
+ from .data import load_data, load_visual_data
+ from .neighbor_loss import get_neighbor_loss
+ from .repair import replace_back, replace_linear_addition_noqk
+ from .singular_defect import singular_defect_directions
+ from .utils import get_tokens, load_model, pca_array
+
+ __all__ = [
+     'singular_defect_directions',
+     'get_neighbor_loss',
+     'pca_array',
+     'get_tokens',
+     'load_data',
+     'load_visual_data',
+     'replace_back',
+     'replace_linear_addition_noqk',
+     'load_model',
+ ]
data.py ADDED
@@ -0,0 +1,130 @@
+ from glob import glob
+
+ import torchvision.transforms as transforms
+ from PIL import Image
+ from torch.utils.data import Dataset
+
+
+ def make_transform(
+     smaller_edge_size: int, patch_size, center_crop=False, max_edge_size=812
+ ) -> transforms.Compose:
+     IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+     IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+     interpolation_mode = transforms.InterpolationMode.BICUBIC
+     assert smaller_edge_size > 0
+
+     # Crop each spatial dimension of the normalized tensor down to a multiple
+     # of the patch size, capped at max_edge_size, so the ViT can tokenize it.
+     crop_to_patch_multiple = transforms.Lambda(
+         lambda img: img[
+             :,
+             : min(max_edge_size, img.shape[1] - img.shape[1] % patch_size),
+             : min(max_edge_size, img.shape[2] - img.shape[2] % patch_size),
+         ]
+     )
+
+     steps = [
+         transforms.Resize(
+             size=smaller_edge_size,
+             interpolation=interpolation_mode,
+             antialias=True,
+         ),
+     ]
+     if center_crop:
+         steps.append(transforms.CenterCrop(smaller_edge_size))
+     steps += [
+         transforms.ToTensor(),
+         transforms.Normalize(
+             mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+         ),
+         crop_to_patch_multiple,
+     ]
+     return transforms.Compose(steps)
+
+
+ class VisualDataset(Dataset):
+     def __init__(self, transform, imgs=None):
+         self.transform = transform
+         if imgs is None:
+             self.files = [
+                 'resources/example.jpg',
+                 'resources/villa.png',
+                 'resources/000000037740.jpg',
+                 'resources/000000064359.jpg',
+                 'resources/000000066635.jpg',
+                 'resources/000000078420.jpg',
+             ]
+         else:
+             self.files = imgs
+
+     def __len__(self):
+         return len(self.files)
+
+     def __getitem__(self, index):
+         img = Image.open(self.files[index]).convert('RGB')
+         if self.transform:
+             img = self.transform(img)
+         return img
+
+
+ class ImageNetDataset(Dataset):
+     def __init__(self, transform, num_train_max=1000000):
+         self.transform = transform
+         self.files = glob('data/imagenet/train/*/*.JPEG')
+         # Subsample evenly; guard against a zero step when there are fewer
+         # than num_train_max training images.
+         step = max(1, len(self.files) // num_train_max)
+         self.files = self.files[::step]
+
+     def __len__(self):
+         return len(self.files)
+
+     def __getitem__(self, index):
+         img = Image.open(self.files[index]).convert('RGB')
+         img = self.transform(img)
+         return img
+
+
+ def load_data(args, model):
+     transform = make_transform(
+         args.resolution, model.patch_size, center_crop=True
+     )
+     dataset = ImageNetDataset(
+         transform=transform, num_train_max=args.num_train_max
+     )
+     return dataset
+
+
+ def load_visual_data(args, model):
+     transform = make_transform(
+         args.visual_size, model.patch_size, max_edge_size=1792
+     )
+     dataset = VisualDataset(transform=transform, imgs=vars(args).get('imgs'))
+     return dataset
environment.yml ADDED
@@ -0,0 +1,285 @@
1
+ name: sinder
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - conda-forge
6
+ dependencies:
7
+ - _libgcc_mutex=0.1
8
+ - _openmp_mutex=4.5
9
+ - alsa-lib=1.2.12
10
+ - aom=3.9.1
11
+ - asttokens=2.4.1
12
+ - black=22.1.0
13
+ - blas=2.116
14
+ - blas-devel=3.9.0
15
+ - blue=0.9.1
16
+ - brotli-python=1.1.0
17
+ - bzip2=1.0.8
18
+ - c-ares=1.33.1
19
+ - ca-certificates=2024.8.30
20
+ - cairo=1.18.0
21
+ - certifi=2024.8.30
22
+ - cffi=1.16.0
23
+ - cfgv=3.3.1
24
+ - charset-normalizer=3.3.2
25
+ - click=8.1.7
26
+ - colorama=0.4.6
27
+ - cuda-cudart=12.1.105
28
+ - cuda-cupti=12.1.105
29
+ - cuda-libraries=12.1.0
30
+ - cuda-nvrtc=12.1.105
31
+ - cuda-nvtx=12.1.105
32
+ - cuda-opencl=12.5.39
33
+ - cuda-runtime=12.1.0
34
+ - cuda-version=12.5
35
+ - dataclasses=0.8
36
+ - dav1d=1.2.1
37
+ - dbus=1.13.6
38
+ - decorator=5.1.1
39
+ - distlib=0.3.8
40
+ - double-conversion=3.3.0
41
+ - exceptiongroup=1.2.2
42
+ - executing=2.0.1
43
+ - expat=2.6.2
44
+ - ffmpeg=6.1.2
45
+ - filelock=3.15.4
46
+ - flake8=4.0.1
47
+ - font-ttf-dejavu-sans-mono=2.37
48
+ - font-ttf-inconsolata=3.000
49
+ - font-ttf-source-code-pro=2.038
50
+ - font-ttf-ubuntu=0.83
51
+ - fontconfig=2.14.2
52
+ - fonts-conda-ecosystem=1
53
+ - fonts-conda-forge=1
54
+ - freeglut=3.2.2
55
+ - freetype=2.12.1
56
+ - fribidi=1.0.10
57
+ - gettext=0.22.5
58
+ - gettext-tools=0.22.5
59
+ - gmp=6.3.0
60
+ - gmpy2=2.1.5
61
+ - gnutls=3.7.9
62
+ - graphite2=1.3.13
63
+ - h2=4.1.0
64
+ - harfbuzz=9.0.0
65
+ - hdf5=1.14.3
66
+ - hpack=4.0.0
67
+ - hyperframe=6.0.1
68
+ - icu=75.1
69
+ - identify=2.6.0
70
+ - idna=3.7
71
+ - imath=3.1.12
72
+ - importlib-metadata=8.5.0
73
+ - ipdb=0.13.13
74
+ - ipython=8.26.0
75
+ - jasper=4.2.4
76
+ - jedi=0.19.1
77
+ - jinja2=3.1.4
78
+ - joblib=1.4.2
79
+ - keyutils=1.6.1
80
+ - krb5=1.21.3
81
+ - lame=3.100
82
+ - lcms2=2.16
83
+ - ld_impl_linux-64=2.40
84
+ - lerc=4.0.0
85
+ - libabseil=20240116.2
86
+ - libaec=1.1.3
87
+ - libasprintf=0.22.5
88
+ - libasprintf-devel=0.22.5
89
+ - libass=0.17.3
90
+ - libblas=3.9.0
91
+ - libcblas=3.9.0
92
+ - libclang-cpp18.1=18.1.8
93
+ - libclang13=19.1.0
94
+ - libcublas=12.1.0.26
95
+ - libcufft=11.0.2.4
96
+ - libcufile=1.10.1.7
97
+ - libcups=2.3.3
98
+ - libcurand=10.3.6.82
99
+ - libcurl=8.10.1
100
+ - libcusolver=11.4.4.55
101
+ - libcusparse=12.0.2.55
102
+ - libdeflate=1.21
103
+ - libdrm=2.4.122
104
+ - libedit=3.1.20191231
105
+ - libegl=1.7.0
106
+ - libev=4.33
107
+ - libexpat=2.6.2
108
+ - libffi=3.4.2
109
+ - libgcc=14.1.0
110
+ - libgcc-ng=14.1.0
111
+ - libgettextpo=0.22.5
112
+ - libgettextpo-devel=0.22.5
113
+ - libgfortran=14.1.0
114
+ - libgfortran-ng=14.1.0
115
+ - libgfortran5=14.1.0
116
+ - libgl=1.7.0
117
+ - libglib=2.80.3
118
+ - libglu=9.0.0
119
+ - libglvnd=1.7.0
120
+ - libglx=1.7.0
121
+ - libgomp=14.1.0
122
+ - libhwloc=2.11.1
123
+ - libiconv=1.17
124
+ - libidn2=2.3.7
125
+ - libjpeg-turbo=3.0.0
126
+ - liblapack=3.9.0
127
+ - liblapacke=3.9.0
128
+ - libllvm18=18.1.8
129
+ - libllvm19=19.1.0
130
+ - libnghttp2=1.58.0
131
+ - libnpp=12.0.2.50
132
+ - libnsl=2.0.1
133
+ - libnvjitlink=12.1.105
134
+ - libnvjpeg=12.1.1.14
135
+ - libopencv=4.10.0
136
+ - libopenvino=2024.4.0
137
+ - libopenvino-auto-batch-plugin=2024.4.0
138
+ - libopenvino-auto-plugin=2024.4.0
139
+ - libopenvino-hetero-plugin=2024.4.0
140
+ - libopenvino-intel-cpu-plugin=2024.4.0
141
+ - libopenvino-intel-gpu-plugin=2024.4.0
142
+ - libopenvino-intel-npu-plugin=2024.4.0
143
+ - libopenvino-ir-frontend=2024.4.0
144
+ - libopenvino-onnx-frontend=2024.4.0
145
+ - libopenvino-paddle-frontend=2024.4.0
146
+ - libopenvino-pytorch-frontend=2024.4.0
147
+ - libopenvino-tensorflow-frontend=2024.4.0
148
+ - libopenvino-tensorflow-lite-frontend=2024.4.0
149
+ - libopus=1.3.1
150
+ - libpciaccess=0.18
151
+ - libpng=1.6.44
152
+ - libpq=16.4
153
+ - libprotobuf=4.25.3
154
+ - libsqlite=3.46.0
155
+ - libssh2=1.11.0
156
+ - libstdcxx=14.1.0
157
+ - libstdcxx-ng=14.1.0
158
+ - libtasn1=4.19.0
159
+ - libtiff=4.7.0
160
+ - libunistring=0.9.10
161
+ - libuuid=2.38.1
162
+ - libva=2.22.0
163
+ - libvpx=1.14.1
164
+ - libwebp-base=1.4.0
165
+ - libxcb=1.16
166
+ - libxcrypt=4.4.36
167
+ - libxkbcommon=1.7.0
168
+ - libxml2=2.12.7
169
+ - libzlib=1.3.1
170
+ - llvm-openmp=15.0.7
171
+ - markupsafe=2.1.5
172
+ - matplotlib-inline=0.1.7
173
+ - mccabe=0.6.1
174
+ - mkl=2022.1.0
175
+ - mkl-devel=2022.1.0
176
+ - mkl-include=2022.1.0
177
+ - mpc=1.3.1
178
+ - mpfr=4.2.1
179
+ - mpmath=1.3.0
180
+ - mypy_extensions=1.0.0
181
+ - mysql-common=9.0.1
182
+ - mysql-libs=9.0.1
183
+ - ncurses=6.5
184
+ - nettle=3.9.1
185
+ - networkx=3.3
186
+ - nodeenv=1.9.1
187
+ - numpy=2.0.0
188
+ - ocl-icd=2.3.2
189
+ - opencv=4.10.0
190
+ - openexr=3.2.2
191
+ - openh264=2.4.1
192
+ - openjpeg=2.5.2
193
+ - openssl=3.3.2
194
+ - p11-kit=0.24.1
195
+ - parso=0.8.4
196
+ - pathspec=0.12.1
197
+ - pcre2=10.44
198
+ - pexpect=4.9.0
199
+ - pickleshare=0.7.5
200
+ - pillow=10.4.0
201
+ - pip=24.0
202
+ - pixman=0.43.2
203
+ - platformdirs=4.2.2
204
+ - pre-commit=3.7.1
205
+ - prompt-toolkit=3.0.47
206
+ - pthread-stubs=0.4
207
+ - ptyprocess=0.7.0
208
+ - pugixml=1.14
209
+ - pure_eval=0.2.2
210
+ - py-opencv=4.10.0
211
+ - pycodestyle=2.8.0
212
+ - pycparser=2.22
213
+ - pyflakes=2.4.0
214
+ - pygments=2.18.0
215
+ - pysocks=1.7.1
216
+ - python=3.10.14
217
+ - python_abi=3.10
218
+ - pytorch=2.3.1
219
+ - pytorch-cuda=12.1
220
+ - pytorch-mutex=1.0
221
+ - pyyaml=6.0.1
222
+ - qt6-main=6.7.2
223
+ - readline=8.2
224
+ - requests=2.32.3
225
+ - scikit-learn=1.5.1
226
+ - scipy=1.14.0
227
+ - setuptools=71.0.4
228
+ - six=1.16.0
229
+ - snappy=1.2.1
230
+ - stack_data=0.6.2
231
+ - svt-av1=2.2.1
232
+ - sympy=1.13.0
233
+ - tbb=2021.13.0
234
+ - threadpoolctl=3.5.0
235
+ - tk=8.6.13
236
+ - toml=0.10.2
237
+ - tomli=2.0.1
238
+ - torchaudio=2.3.1
239
+ - torchtriton=2.3.1
240
+ - torchvision=0.18.1
241
+ - tqdm=4.66.5
242
+ - traitlets=5.14.3
243
+ - typed-ast=1.5.5
244
+ - typing_extensions=4.12.2
245
+ - tzdata=2024a
246
+ - ukkonen=1.0.1
247
+ - urllib3=2.2.2
248
+ - virtualenv=20.26.3
249
+ - wayland=1.23.0
250
+ - wayland-protocols=1.36
251
+ - wcwidth=0.2.13
252
+ - wheel=0.43.0
253
+ - x264=1!164.3095
254
+ - x265=3.5
255
+ - xcb-util=0.4.1
256
+ - xcb-util-cursor=0.1.5
257
+ - xcb-util-image=0.4.0
258
+ - xcb-util-keysyms=0.4.1
259
+ - xcb-util-renderutil=0.3.10
260
+ - xcb-util-wm=0.4.2
261
+ - xkeyboard-config=2.42
262
+ - xorg-fixesproto=5.0
263
+ - xorg-inputproto=2.3.2
264
+ - xorg-kbproto=1.0.7
265
+ - xorg-libice=1.1.1
266
+ - xorg-libsm=1.2.4
267
+ - xorg-libx11=1.8.9
268
+ - xorg-libxau=1.0.11
269
+ - xorg-libxdmcp=1.1.3
270
+ - xorg-libxext=1.3.4
271
+ - xorg-libxfixes=5.0.3
272
+ - xorg-libxi=1.7.10
273
+ - xorg-libxrender=0.9.11
274
+ - xorg-libxtst=1.2.5
275
+ - xorg-libxxf86vm=1.1.5
276
+ - xorg-recordproto=1.14.2
277
+ - xorg-renderproto=0.11.1
278
+ - xorg-xextproto=7.3.0
279
+ - xorg-xproto=7.0.31
280
+ - xz=5.2.6
281
+ - yaml=0.2.5
282
+ - zipp=3.20.2
283
+ - zlib=1.3.1
284
+ - zstandard=0.23.0
285
+ - zstd=1.5.6
hub/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
hub/backbones.py ADDED
@@ -0,0 +1,201 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ from typing import Union
8
+
9
+ import torch
10
+
11
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
12
+
13
+
14
+ class Weights(Enum):
15
+ LVD142M = 'LVD142M'
16
+
17
+
18
+ def _make_dinov2_model(
19
+ *,
20
+ arch_name: str = 'vit_large',
21
+ img_size: int = 518,
22
+ patch_size: int = 14,
23
+ init_values: float = 1.0,
24
+ ffn_layer: str = 'mlp',
25
+ block_chunks: int = 0,
26
+ num_register_tokens: int = 0,
27
+ interpolate_antialias: bool = False,
28
+ interpolate_offset: float = 0.1,
29
+ pretrained: bool = True,
30
+ weights: Union[Weights, str] = Weights.LVD142M,
31
+ **kwargs,
32
+ ):
33
+ from ..models.dinov2 import vision_transformer as vits
34
+
35
+ if isinstance(weights, str):
36
+ try:
37
+ weights = Weights[weights]
38
+ except KeyError:
39
+ raise AssertionError(f'Unsupported weights: {weights}')
40
+
41
+ model_base_name = _make_dinov2_model_name(arch_name, patch_size)
42
+ vit_kwargs = dict(
43
+ img_size=img_size,
44
+ patch_size=patch_size,
45
+ init_values=init_values,
46
+ ffn_layer=ffn_layer,
47
+ block_chunks=block_chunks,
48
+ num_register_tokens=num_register_tokens,
49
+ interpolate_antialias=interpolate_antialias,
50
+ interpolate_offset=interpolate_offset,
51
+ )
52
+ vit_kwargs.update(**kwargs)
53
+ model = vits.__dict__[arch_name](**vit_kwargs)
54
+
55
+ if pretrained:
56
+ model_full_name = _make_dinov2_model_name(
57
+ arch_name, patch_size, num_register_tokens
58
+ )
59
+ url = (
60
+ _DINOV2_BASE_URL
61
+ + f'/{model_base_name}/{model_full_name}_pretrain.pth'
62
+ )
63
+ state_dict = torch.hub.load_state_dict_from_url(
64
+ url, map_location='cpu'
65
+ )
66
+ model.load_state_dict(state_dict, strict=True)
67
+
68
+ return model
69
+
70
+
71
+ def dinov2_vits14(
72
+ *,
73
+ pretrained: bool = True,
74
+ weights: Union[Weights, str] = Weights.LVD142M,
75
+ **kwargs,
76
+ ):
77
+ """DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M
78
+ dataset."""
79
+ return _make_dinov2_model(
80
+ arch_name='vit_small', pretrained=pretrained, weights=weights, **kwargs
81
+ )
82
+
83
+
84
+ def dinov2_vitb14(
85
+ *,
86
+ pretrained: bool = True,
87
+ weights: Union[Weights, str] = Weights.LVD142M,
88
+ **kwargs,
89
+ ):
90
+ """DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M
91
+ dataset."""
92
+ return _make_dinov2_model(
93
+ arch_name='vit_base', pretrained=pretrained, weights=weights, **kwargs
94
+ )
95
+
96
+
97
+ def dinov2_vitl14(
98
+ *,
99
+ pretrained: bool = True,
100
+ weights: Union[Weights, str] = Weights.LVD142M,
101
+ **kwargs,
102
+ ):
103
+ """DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M
104
+ dataset."""
105
+ return _make_dinov2_model(
106
+ arch_name='vit_large', pretrained=pretrained, weights=weights, **kwargs
107
+ )
108
+
109
+
110
+ def dinov2_vitg14(
111
+ *,
112
+ pretrained: bool = True,
113
+ weights: Union[Weights, str] = Weights.LVD142M,
114
+ **kwargs,
115
+ ):
116
+ """DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M
117
+ dataset."""
118
+ return _make_dinov2_model(
119
+ arch_name='vit_giant2',
120
+ ffn_layer='swiglufused',
121
+ weights=weights,
122
+ pretrained=pretrained,
123
+ **kwargs,
124
+ )
125
+
126
+
127
+ def dinov2_vits14_reg(
128
+ *,
129
+ pretrained: bool = True,
130
+ weights: Union[Weights, str] = Weights.LVD142M,
131
+ **kwargs,
132
+ ):
133
+ """DINOv2 ViT-S/14 model with registers (optionally) pretrained on the
134
+ LVD-142M dataset."""
135
+ return _make_dinov2_model(
136
+ arch_name='vit_small',
137
+ pretrained=pretrained,
138
+ weights=weights,
139
+ num_register_tokens=4,
140
+ interpolate_antialias=True,
141
+ interpolate_offset=0.0,
142
+ **kwargs,
143
+ )
144
+
145
+
146
+ def dinov2_vitb14_reg(
147
+ *,
148
+ pretrained: bool = True,
149
+ weights: Union[Weights, str] = Weights.LVD142M,
150
+ **kwargs,
151
+ ):
152
+ """DINOv2 ViT-B/14 model with registers (optionally) pretrained on the
153
+ LVD-142M dataset."""
154
+ return _make_dinov2_model(
155
+ arch_name='vit_base',
156
+ pretrained=pretrained,
157
+ weights=weights,
158
+ num_register_tokens=4,
159
+ interpolate_antialias=True,
160
+ interpolate_offset=0.0,
161
+ **kwargs,
162
+ )
163
+
164
+
165
+ def dinov2_vitl14_reg(
166
+ *,
167
+ pretrained: bool = True,
168
+ weights: Union[Weights, str] = Weights.LVD142M,
169
+ **kwargs,
170
+ ):
171
+ """DINOv2 ViT-L/14 model with registers (optionally) pretrained on the
172
+ LVD-142M dataset."""
173
+ return _make_dinov2_model(
174
+ arch_name='vit_large',
175
+ pretrained=pretrained,
176
+ weights=weights,
177
+ num_register_tokens=4,
178
+ interpolate_antialias=True,
179
+ interpolate_offset=0.0,
180
+ **kwargs,
181
+ )
182
+
183
+
184
+ def dinov2_vitg14_reg(
185
+ *,
186
+ pretrained: bool = True,
187
+ weights: Union[Weights, str] = Weights.LVD142M,
188
+ **kwargs,
189
+ ):
190
+ """DINOv2 ViT-g/14 model with registers (optionally) pretrained on the
191
+ LVD-142M dataset."""
192
+ return _make_dinov2_model(
193
+ arch_name='vit_giant2',
194
+ ffn_layer='swiglufused',
195
+ weights=weights,
196
+ pretrained=pretrained,
197
+ num_register_tokens=4,
198
+ interpolate_antialias=True,
199
+ interpolate_offset=0.0,
200
+ **kwargs,
201
+ )
hub/utils.py ADDED
@@ -0,0 +1,46 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ _DINOV2_BASE_URL = 'https://dl.fbaipublicfiles.com/dinov2'
+
+
+ def _make_dinov2_model_name(
+     arch_name: str, patch_size: int, num_register_tokens: int = 0
+ ) -> str:
+     compact_arch_name = arch_name.replace('_', '')[:4]
+     registers_suffix = (
+         f'_reg{num_register_tokens}' if num_register_tokens else ''
+     )
+     return f'dinov2_{compact_arch_name}{patch_size}{registers_suffix}'
+
+
+ class CenterPadding(nn.Module):
+     def __init__(self, multiple):
+         super().__init__()
+         self.multiple = multiple
+
+     def _get_pad(self, size):
+         new_size = math.ceil(size / self.multiple) * self.multiple
+         pad_size = new_size - size
+         pad_size_left = pad_size // 2
+         pad_size_right = pad_size - pad_size_left
+         return pad_size_left, pad_size_right
+
+     @torch.inference_mode()
+     def forward(self, x):
+         pads = list(
+             itertools.chain.from_iterable(
+                 self._get_pad(m) for m in x.shape[:1:-1]
+             )
+         )
+         output = F.pad(x, pads)
+         return output
hubconf.py ADDED
@@ -0,0 +1,29 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+
+ from sinder.hub.backbones import (
+     dinov2_vitb14,
+     dinov2_vitb14_reg,
+     dinov2_vitg14,
+     dinov2_vitg14_reg,
+     dinov2_vitl14,
+     dinov2_vitl14_reg,
+     dinov2_vits14,
+     dinov2_vits14_reg,
+ )
+
+ dependencies = ['torch']
+
+ __all__ = [
+     'dinov2_vitb14',
+     'dinov2_vitb14_reg',
+     'dinov2_vitg14',
+     'dinov2_vitg14_reg',
+     'dinov2_vitl14',
+     'dinov2_vitl14_reg',
+     'dinov2_vits14',
+     'dinov2_vits14_reg',
+ ]
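
Since `hubconf.py` re-exports the DINOv2 entrypoints, the backbones can also be instantiated through `torch.hub` from a local clone of the repository. A minimal sketch, assuming the repository has been cloned to `./sinder` (the path is a placeholder):

```python
import torch

# Load the ViT-g/14 entrypoint defined in hubconf.py from a local clone;
# pretrained=True downloads the original DINOv2 weights from Meta's servers.
model = torch.hub.load('./sinder', 'dinov2_vitg14', source='local', pretrained=True)
model.eval()
```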
main.py ADDED
@@ -0,0 +1,238 @@
1
+ #!/usr/bin/env python
2
+ import argparse
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ from sinder import (
13
+ get_neighbor_loss,
14
+ get_tokens,
15
+ load_data,
16
+ load_model,
17
+ load_visual_data,
18
+ pca_array,
19
+ replace_back,
20
+ replace_linear_addition_noqk,
21
+ )
22
+
23
+ os.environ['XFORMERS_DISABLED'] = '1'
24
+ torch.set_float32_matmul_precision('high')
25
+
26
+
27
+ def parse_args():
28
+ parser = argparse.ArgumentParser(description='Beautify network')
29
+ parser.add_argument(
30
+ '--model', type=str, default='dinov2_vitg14', help='config file'
31
+ )
32
+ parser.add_argument('--work_dir', type=str, default='results')
33
+ parser.add_argument('--resolution', type=int, default=518)
34
+ parser.add_argument('--lr', type=float, default=0.005)
35
+ parser.add_argument('--max_iter', type=int, default=30000)
36
+ parser.add_argument('--num_train_max', type=int, default=30000)
37
+ parser.add_argument('--mask_thr', type=float, default=4)
38
+ parser.add_argument('--skip_less_than', type=int, default=3)
39
+ parser.add_argument('--visual_size', type=int, default=448 * 2)
40
+ parser.add_argument('--kernel', type=int, default=3)
41
+ parser.add_argument('--save_at_skip', type=int, nargs='+', default=[75])
42
+ parser.add_argument('--limit_layers', type=int, default=10)
43
+
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def prepare_train(args, model):
49
+ model.train()
50
+
51
+ all_params = []
52
+ for name, param in model.named_parameters():
53
+ param.requires_grad = False
54
+
55
+ replace_linear_addition_noqk(model, 'model')
56
+ for name, param in model.named_parameters():
57
+ if '.epsilon' in name and param.requires_grad is True:
58
+ all_params.append(param)
59
+
60
+ grad_params = []
61
+ for name, param in model.named_parameters():
62
+ if param.requires_grad:
63
+ grad_params.append(name)
64
+
65
+ assert len(grad_params) == len(all_params)
66
+ print(len(grad_params), grad_params)
67
+ print(len(all_params), all_params)
68
+ optimizer = torch.optim.SGD(
69
+ all_params,
70
+ lr=args.lr,
71
+ momentum=0.9,
72
+ )
73
+
74
+ return optimizer
75
+
76
+
77
+ def save_model(args, model):
78
+ print('save model')
79
+ model.eval()
80
+
81
+ replace_back(model, 'model')
82
+
83
+ torch.save(model.state_dict(), args.folder / 'model.pt')
84
+
85
+
86
+ def train(args, model, dataset, optimizer, visual_dataset):
87
+ print('training')
88
+ skip_history = [False] * 1000
89
+ model.train()
90
+
91
+ for global_iter in tqdm(range(args.max_iter)):
92
+ img = dataset[global_iter % len(dataset)]
93
+ H = img.shape[1] // model.patch_size
94
+ W = img.shape[2] // model.patch_size
95
+ density = np.array(skip_history[-1000:]).astype(float).mean()
96
+ print(f'{global_iter=} {W=} {H=} {density=:.2f}')
97
+
98
+ for percent in args.save_at_skip:
99
+ if percent / 100 <= density:
100
+ print(f'save checkpoint at {density=}')
101
+ args.save_at_skip.remove(percent)
102
+ torch.save(model, args.folder / f'checkpoint_p{percent}.pth')
103
+ if len(args.save_at_skip) == 0:
104
+ break
105
+
106
+ model.zero_grad()
107
+
108
+ model.train()
109
+ with torch.enable_grad():
110
+ image_batch = img.unsqueeze(0).cuda()
111
+ result = get_neighbor_loss(
112
+ model,
113
+ image_batch,
114
+ skip_less_than=args.skip_less_than,
115
+ mask_thr=args.mask_thr,
116
+ kernel=args.kernel,
117
+ )
118
+
119
+ if result is None:
120
+ skip_history.append(True)
121
+ print('no loss, skip')
122
+ else:
123
+ skip_history.append(False)
124
+ (
125
+ layer,
126
+ loss,
127
+ I,
128
+ J,
129
+ T,
130
+ alpha,
131
+ mask_angle,
132
+ x_token,
133
+ ) = result
134
+
135
+ print(
136
+ f'{global_iter=}, {layer=}, {density=}, {alpha=:.2f}, {len(I)=}, '
137
+ f'{loss.item()=:.2f}'
138
+ )
139
+
140
+ if torch.isnan(loss).any():
141
+ print('nan loss, skip')
142
+ continue
143
+ loss.backward()
144
+
145
+ # set some grad to 0
146
+ if args.limit_layers:
147
+ with torch.no_grad():
148
+ for t in range(layer - args.limit_layers + 1):
149
+ for p in model.blocks[t].parameters():
150
+ p.grad = None
151
+
152
+ has_nan = False
153
+ for name, param in model.named_parameters():
154
+ if param.grad is not None and torch.isnan(param.grad).any():
155
+ print(f'nan grad at {name}, skip')
156
+ has_nan = True
157
+ if has_nan:
158
+ continue
159
+ optimizer.step()
160
+
161
+ # visualize
162
+ if global_iter % 100 == 0:
163
+ try:
164
+ print(f'visualization at {global_iter=}')
165
+ pca_img = pca_array(x_token)
166
+ pca_img.save(args.folder / 'pca.png')
167
+ mask_img = Image.fromarray(
168
+ (mask_angle * 255)
169
+ .detach()
170
+ .cpu()
171
+ .numpy()
172
+ .astype(np.uint8)
173
+ ).resize((W * 7, H * 7), resample=Image.NEAREST)
174
+ mask_img.save(args.folder / 'mask.png')
175
+ Image.fromarray(
176
+ (
177
+ (
178
+ img.permute((1, 2, 0)).cpu().numpy() * 0.22
179
+ + 0.45
180
+ )
181
+ * 255
182
+ )
183
+ .clip(0, 255)
184
+ .astype(np.uint8)
185
+ ).save(args.folder / 'img.png')
186
+ if global_iter % 1000 == 0:
187
+ pca_img.save(args.folder / f'{global_iter:05}_pca.png')
188
+ mask_img.save(
189
+ args.folder / f'{global_iter:05}_mask.png'
190
+ )
191
+ Image.fromarray(
192
+ (
193
+ (
194
+ img.permute((1, 2, 0)).cpu().numpy() * 0.22
195
+ + 0.45
196
+ )
197
+ * 255
198
+ )
199
+ .clip(0, 255)
200
+ .astype(np.uint8)
201
+ ).save(args.folder / f'{global_iter:05}_img.png')
202
+ except Exception as e:
203
+ print(e)
204
+
205
+ for d in range(len(visual_dataset)):
206
+ visual_image = visual_dataset[d]
207
+ visual_tokens_all = get_tokens(model, visual_image)
208
+ visual_tokens, visual_tokens_cls = zip(*visual_tokens_all)
209
+ pca_img = pca_array(visual_tokens[-1])
210
+ pca_img.save(args.folder / f'{d}_pca.png')
211
+ if global_iter % 500 == 0:
212
+ pca_img.save(
213
+ args.folder / f'{global_iter:05}_{d}_pca.png'
214
+ )
215
+
216
+ torch.save(model, args.folder / 'checkpoint.pth')
217
+
218
+
219
+ def main():
220
+ print('Start beautify')
221
+ args = parse_args()
222
+
223
+ name = f'res{args.resolution}_lr{args.lr}_{args.num_train_max}_skipless{args.skip_less_than}_maskthr{args.mask_thr}_limit{args.limit_layers}_ker{args.kernel}'
224
+ args.folder = Path(args.work_dir) / name
225
+ os.makedirs(args.folder, exist_ok=True)
226
+ print(args)
227
+ print(' '.join(sys.argv))
228
+ print(f'work dir {args.folder}')
229
+ model = load_model(args.model)
230
+ dataset = load_data(args, model)
231
+ visual_dataset = load_visual_data(args, model)
232
+ optimizer = prepare_train(args, model)
233
+ train(args, model, dataset, optimizer, visual_dataset)
234
+ save_model(args, model)
235
+
236
+
237
+ if __name__ == '__main__':
238
+ main()
models/dinov2/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from . import vision_transformer as vits
9
+
10
+ logger = logging.getLogger('dinov2')
11
+
12
+
13
+ def build_model(args, only_teacher=False, img_size=224):
14
+ args.arch = args.arch.removesuffix('_memeff')
15
+ if 'vit' in args.arch:
16
+ vit_kwargs = dict(
17
+ img_size=img_size,
18
+ patch_size=args.patch_size,
19
+ init_values=args.layerscale,
20
+ ffn_layer=args.ffn_layer,
21
+ block_chunks=args.block_chunks,
22
+ qkv_bias=args.qkv_bias,
23
+ proj_bias=args.proj_bias,
24
+ ffn_bias=args.ffn_bias,
25
+ num_register_tokens=args.num_register_tokens,
26
+ interpolate_offset=args.interpolate_offset,
27
+ interpolate_antialias=args.interpolate_antialias,
28
+ )
29
+ teacher = vits.__dict__[args.arch](**vit_kwargs)
30
+ if only_teacher:
31
+ return teacher, teacher.embed_dim
32
+ student = vits.__dict__[args.arch](
33
+ **vit_kwargs,
34
+ drop_path_rate=args.drop_path_rate,
35
+ drop_path_uniform=args.drop_path_uniform,
36
+ )
37
+ embed_dim = student.embed_dim
38
+ return student, teacher, embed_dim
39
+
40
+
41
+ def build_model_from_cfg(cfg, only_teacher=False):
42
+ return build_model(
43
+ cfg.student,
44
+ only_teacher=only_teacher,
45
+ img_size=cfg.crops.global_crops_size,
46
+ )
models/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from .attention import MemEffAttention
+ from .block import NestedTensorBlock
+ from .dino_head import DINOHead
+ from .mlp import Mlp
+ from .patch_embed import PatchEmbed
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
models/dinov2/layers/attention.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor, nn
15
+
16
+ logger = logging.getLogger('dinov2')
17
+
18
+
19
+ XFORMERS_ENABLED = os.environ.get('XFORMERS_DISABLED') is None
20
+ try:
21
+ if XFORMERS_ENABLED:
22
+ from xformers.ops import memory_efficient_attention, unbind
23
+
24
+ XFORMERS_AVAILABLE = True
25
+ warnings.warn('xFormers is available (Attention)')
26
+ else:
27
+ raise ImportError
28
+ except ImportError:
29
+ XFORMERS_AVAILABLE = False
30
+
31
+
32
+ class Attention(nn.Module):
33
+ def __init__(
34
+ self,
35
+ dim: int,
36
+ num_heads: int = 8,
37
+ qkv_bias: bool = False,
38
+ proj_bias: bool = True,
39
+ attn_drop: float = 0.0,
40
+ proj_drop: float = 0.0,
41
+ ) -> None:
42
+ super().__init__()
43
+ self.num_heads = num_heads
44
+ head_dim = dim // num_heads
45
+ self.scale = head_dim**-0.5
46
+
47
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
48
+ self.attn_drop = nn.Dropout(attn_drop)
49
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
50
+ self.proj_drop = nn.Dropout(proj_drop)
51
+
52
+ def forward(self, x: Tensor) -> Tensor:
53
+ B, N, C = x.shape
54
+ qkv = (
55
+ self.qkv(x)
56
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
57
+ .permute(2, 0, 3, 1, 4)
58
+ )
59
+
60
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
61
+ attn = q @ k.transpose(-2, -1)
62
+
63
+ attn = attn.softmax(dim=-1)
64
+ attn = self.attn_drop(attn)
65
+
66
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
67
+ x = self.proj(x)
68
+ x = self.proj_drop(x)
69
+ return x
70
+
71
+
72
+ class MemEffAttention(Attention):
73
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
74
+ if not XFORMERS_AVAILABLE:
75
+ if attn_bias is not None:
76
+ raise AssertionError(
77
+ 'xFormers is required for using nested tensors'
78
+ )
79
+ return super().forward(x)
80
+
81
+ B, N, C = x.shape
82
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
83
+
84
+ q, k, v = unbind(qkv, 2)
85
+
86
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
87
+ x = x.reshape([B, N, C])
88
+
89
+ x = self.proj(x)
90
+ x = self.proj_drop(x)
91
+ return x
models/dinov2/layers/block.py ADDED
@@ -0,0 +1,304 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+ from typing import Any, Callable, Dict, List, Tuple
14
+
15
+ import torch
16
+ from torch import Tensor, nn
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+ logger = logging.getLogger('dinov2')
24
+
25
+
26
+ XFORMERS_ENABLED = os.environ.get('XFORMERS_DISABLED') is None
27
+ try:
28
+ if XFORMERS_ENABLED:
29
+ from xformers.ops import fmha, index_select_cat, scaled_index_add
30
+
31
+ XFORMERS_AVAILABLE = True
32
+ warnings.warn('xFormers is available (Block)')
33
+ else:
34
+ raise ImportError
35
+ except ImportError:
36
+ XFORMERS_AVAILABLE = False
37
+
38
+
39
+ class Block(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ num_heads: int,
44
+ mlp_ratio: float = 4.0,
45
+ qkv_bias: bool = False,
46
+ proj_bias: bool = True,
47
+ ffn_bias: bool = True,
48
+ drop: float = 0.0,
49
+ attn_drop: float = 0.0,
50
+ init_values=None,
51
+ drop_path: float = 0.0,
52
+ act_layer: Callable[..., nn.Module] = nn.GELU,
53
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
54
+ attn_class: Callable[..., nn.Module] = Attention,
55
+ ffn_layer: Callable[..., nn.Module] = Mlp,
56
+ ) -> None:
57
+ super().__init__()
58
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
59
+ self.norm1 = norm_layer(dim)
60
+ self.attn = attn_class(
61
+ dim,
62
+ num_heads=num_heads,
63
+ qkv_bias=qkv_bias,
64
+ proj_bias=proj_bias,
65
+ attn_drop=attn_drop,
66
+ proj_drop=drop,
67
+ )
68
+ self.ls1 = (
69
+ LayerScale(dim, init_values=init_values)
70
+ if init_values
71
+ else nn.Identity()
72
+ )
73
+ self.drop_path1 = (
74
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
75
+ )
76
+
77
+ self.norm2 = norm_layer(dim)
78
+ mlp_hidden_dim = int(dim * mlp_ratio)
79
+ self.mlp = ffn_layer(
80
+ in_features=dim,
81
+ hidden_features=mlp_hidden_dim,
82
+ act_layer=act_layer,
83
+ drop=drop,
84
+ bias=ffn_bias,
85
+ )
86
+ self.ls2 = (
87
+ LayerScale(dim, init_values=init_values)
88
+ if init_values
89
+ else nn.Identity()
90
+ )
91
+ self.drop_path2 = (
92
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
93
+ )
94
+
95
+ self.sample_drop_ratio = drop_path
96
+
97
+ def forward(self, x: Tensor) -> Tensor:
98
+ def attn_residual_func(x: Tensor) -> Tensor:
99
+ return self.ls1(self.attn(self.norm1(x)))
100
+
101
+ def ffn_residual_func(x: Tensor) -> Tensor:
102
+ return self.ls2(self.mlp(self.norm2(x)))
103
+
104
+ if self.training and self.sample_drop_ratio > 0.1:
105
+ # the overhead is compensated only for a drop path rate larger than 0.1
106
+ x = drop_add_residual_stochastic_depth(
107
+ x,
108
+ residual_func=attn_residual_func,
109
+ sample_drop_ratio=self.sample_drop_ratio,
110
+ )
111
+ x = drop_add_residual_stochastic_depth(
112
+ x,
113
+ residual_func=ffn_residual_func,
114
+ sample_drop_ratio=self.sample_drop_ratio,
115
+ )
116
+ elif self.training and self.sample_drop_ratio > 0.0:
117
+ x = x + self.drop_path1(attn_residual_func(x))
118
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
119
+ else:
120
+ x = x + attn_residual_func(x)
121
+ x = x + ffn_residual_func(x)
122
+ return x
123
+
124
+
125
+ def drop_add_residual_stochastic_depth(
126
+ x: Tensor,
127
+ residual_func: Callable[[Tensor], Tensor],
128
+ sample_drop_ratio: float = 0.0,
129
+ ) -> Tensor:
130
+ # 1) extract subset using permutation
131
+ b, n, d = x.shape
132
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
133
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
134
+ x_subset = x[brange]
135
+
136
+ # 2) apply residual_func to get residual
137
+ residual = residual_func(x_subset)
138
+
139
+ x_flat = x.flatten(1)
140
+ residual = residual.flatten(1)
141
+
142
+ residual_scale_factor = b / sample_subset_size
143
+
144
+ # 3) add the residual
145
+ x_plus_residual = torch.index_add(
146
+ x_flat,
147
+ 0,
148
+ brange,
149
+ residual.to(dtype=x.dtype),
150
+ alpha=residual_scale_factor,
151
+ )
152
+ return x_plus_residual.view_as(x)
153
+
154
+
155
+ def get_branges_scales(x, sample_drop_ratio=0.0):
156
+ b, n, d = x.shape
157
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
158
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
159
+ residual_scale_factor = b / sample_subset_size
160
+ return brange, residual_scale_factor
161
+
162
+
163
+ def add_residual(
164
+ x, brange, residual, residual_scale_factor, scaling_vector=None
165
+ ):
166
+ if scaling_vector is None:
167
+ x_flat = x.flatten(1)
168
+ residual = residual.flatten(1)
169
+ x_plus_residual = torch.index_add(
170
+ x_flat,
171
+ 0,
172
+ brange,
173
+ residual.to(dtype=x.dtype),
174
+ alpha=residual_scale_factor,
175
+ )
176
+ else:
177
+ x_plus_residual = scaled_index_add(
178
+ x,
179
+ brange,
180
+ residual.to(dtype=x.dtype),
181
+ scaling=scaling_vector,
182
+ alpha=residual_scale_factor,
183
+ )
184
+ return x_plus_residual
185
+
186
+
187
+ attn_bias_cache: Dict[Tuple, Any] = {}
188
+
189
+
190
+ def get_attn_bias_and_cat(x_list, branges=None):
191
+ """This will perform the index select, cat the tensors, and provide the
192
+ attn_bias from cache."""
193
+ batch_sizes = (
194
+ [b.shape[0] for b in branges]
195
+ if branges is not None
196
+ else [x.shape[0] for x in x_list]
197
+ )
198
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
199
+ if all_shapes not in attn_bias_cache.keys():
200
+ seqlens = []
201
+ for b, x in zip(batch_sizes, x_list):
202
+ for _ in range(b):
203
+ seqlens.append(x.shape[1])
204
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
205
+ attn_bias._batch_sizes = batch_sizes
206
+ attn_bias_cache[all_shapes] = attn_bias
207
+
208
+ if branges is not None:
209
+ cat_tensors = index_select_cat(
210
+ [x.flatten(1) for x in x_list], branges
211
+ ).view(1, -1, x_list[0].shape[-1])
212
+ else:
213
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
214
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
215
+
216
+ return attn_bias_cache[all_shapes], cat_tensors
217
+
218
+
219
+ def drop_add_residual_stochastic_depth_list(
220
+ x_list: List[Tensor],
221
+ residual_func: Callable[[Tensor, Any], Tensor],
222
+ sample_drop_ratio: float = 0.0,
223
+ scaling_vector=None,
224
+ ) -> Tensor:
225
+ # 1) generate random set of indices for dropping samples in the batch
226
+ branges_scales = [
227
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio)
228
+ for x in x_list
229
+ ]
230
+ branges = [s[0] for s in branges_scales]
231
+ residual_scale_factors = [s[1] for s in branges_scales]
232
+
233
+ # 2) get attention bias and index+concat the tensors
234
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
235
+
236
+ # 3) apply residual_func to get residual, and split the result
237
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
238
+
239
+ outputs = []
240
+ for x, brange, residual, residual_scale_factor in zip(
241
+ x_list, branges, residual_list, residual_scale_factors
242
+ ):
243
+ outputs.append(
244
+ add_residual(
245
+ x, brange, residual, residual_scale_factor, scaling_vector
246
+ ).view_as(x)
247
+ )
248
+ return outputs
249
+
250
+
251
+ class NestedTensorBlock(Block):
252
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
253
+ """x_list contains a list of tensors to nest together and run."""
254
+ assert isinstance(self.attn, MemEffAttention)
255
+
256
+ if self.training and self.sample_drop_ratio > 0.0:
257
+
258
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
259
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
260
+
261
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
262
+ return self.mlp(self.norm2(x))
263
+
264
+ x_list = drop_add_residual_stochastic_depth_list(
265
+ x_list,
266
+ residual_func=attn_residual_func,
267
+ sample_drop_ratio=self.sample_drop_ratio,
268
+ scaling_vector=self.ls1.gamma
269
+ if isinstance(self.ls1, LayerScale)
270
+ else None,
271
+ )
272
+ x_list = drop_add_residual_stochastic_depth_list(
273
+ x_list,
274
+ residual_func=ffn_residual_func,
275
+ sample_drop_ratio=self.sample_drop_ratio,
276
+ scaling_vector=self.ls2.gamma
277
+ if isinstance(self.ls1, LayerScale)
278
+ else None,
279
+ )
280
+ return x_list
281
+ else:
282
+
283
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
284
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
285
+
286
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
287
+ return self.ls2(self.mlp(self.norm2(x)))
288
+
289
+ attn_bias, x = get_attn_bias_and_cat(x_list)
290
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
291
+ x = x + ffn_residual_func(x)
292
+ return attn_bias.split(x)
293
+
294
+ def forward(self, x_or_x_list):
295
+ if isinstance(x_or_x_list, Tensor):
296
+ return super().forward(x_or_x_list)
297
+ elif isinstance(x_or_x_list, list):
298
+ if not XFORMERS_AVAILABLE:
299
+ raise AssertionError(
300
+ 'xFormers is required for using nested tensors'
301
+ )
302
+ return self.forward_nested(x_or_x_list)
303
+ else:
304
+ raise AssertionError
models/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(
26
+ nlayers,
27
+ in_dim,
28
+ bottleneck_dim,
29
+ hidden_dim=hidden_dim,
30
+ use_bn=use_bn,
31
+ bias=mlp_bias,
32
+ )
33
+ self.apply(self._init_weights)
34
+ self.last_layer = weight_norm(
35
+ nn.Linear(bottleneck_dim, out_dim, bias=False)
36
+ )
37
+ self.last_layer.weight_g.data.fill_(1)
38
+
39
+ def _init_weights(self, m):
40
+ if isinstance(m, nn.Linear):
41
+ trunc_normal_(m.weight, std=0.02)
42
+ if isinstance(m, nn.Linear) and m.bias is not None:
43
+ nn.init.constant_(m.bias, 0)
44
+
45
+ def forward(self, x):
46
+ x = self.mlp(x)
47
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
48
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
49
+ x = self.last_layer(x)
50
+ return x
51
+
52
+
53
+ def _build_mlp(
54
+ nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
55
+ ):
56
+ if nlayers == 1:
57
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
58
+ else:
59
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
60
+ if use_bn:
61
+ layers.append(nn.BatchNorm1d(hidden_dim))
62
+ layers.append(nn.GELU())
63
+ for _ in range(nlayers - 2):
64
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
65
+ if use_bn:
66
+ layers.append(nn.BatchNorm1d(hidden_dim))
67
+ layers.append(nn.GELU())
68
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
69
+ return nn.Sequential(*layers)
models/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (
19
+ x.ndim - 1
20
+ ) # work with diff dim tensors, not just 2D ConvNets
21
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
22
+ if keep_prob > 0.0:
23
+ random_tensor.div_(keep_prob)
24
+ output = x * random_tensor
25
+ return output
26
+
27
+
28
+ class DropPath(nn.Module):
29
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
30
+ residual blocks)."""
31
+
32
+ def __init__(self, drop_prob=None):
33
+ super().__init__()
34
+ self.drop_prob = drop_prob
35
+
36
+ def forward(self, x):
37
+ return drop_path(x, self.drop_prob, self.training)
models/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,26 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor, nn
12
+
13
+
14
+ class LayerScale(nn.Module):
15
+ def __init__(
16
+ self,
17
+ dim: int,
18
+ init_values: Union[float, Tensor] = 1e-5,
19
+ inplace: bool = False,
20
+ ) -> None:
21
+ super().__init__()
22
+ self.inplace = inplace
23
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
24
+
25
+ def forward(self, x: Tensor) -> Tensor:
26
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
models/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
models/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ import torch.nn as nn
13
+ from torch import Tensor
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(
66
+ in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
67
+ )
68
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
69
+
70
+ def forward(self, x: Tensor) -> Tensor:
71
+ _, _, H, W = x.shape
72
+ patch_H, patch_W = self.patch_size
73
+
74
+ assert (
75
+ H % patch_H == 0
76
+ ), f'Input image height {H} is not a multiple of patch height {patch_H}'
77
+ assert (
78
+ W % patch_W == 0
79
+ ), f'Input image width {W} is not a multiple of patch width: {patch_W}'
80
+
81
+ x = self.proj(x) # B C H W
82
+ H, W = x.size(2), x.size(3)
83
+ x = x.flatten(2).transpose(1, 2) # B HW C
84
+ x = self.norm(x)
85
+ if not self.flatten_embedding:
86
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
87
+ return x
88
+
89
+ def flops(self) -> float:
90
+ Ho, Wo = self.patches_resolution
91
+ flops = (
92
+ Ho
93
+ * Wo
94
+ * self.embed_dim
95
+ * self.in_chans
96
+ * (self.patch_size[0] * self.patch_size[1])
97
+ )
98
+ if self.norm is not None:
99
+ flops += Ho * Wo * self.embed_dim
100
+ return flops
models/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import warnings
8
+ from typing import Callable, Optional
9
+
10
+ import torch.nn.functional as F
11
+ from torch import Tensor, nn
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get('XFORMERS_DISABLED') is None
38
+ try:
39
+ if XFORMERS_ENABLED:
40
+ from xformers.ops import SwiGLU
41
+
42
+ XFORMERS_AVAILABLE = True
43
+ warnings.warn('xFormers is available (SwiGLU)')
44
+ else:
45
+ raise ImportError
46
+ except ImportError:
47
+ SwiGLU = SwiGLUFFN
48
+ XFORMERS_AVAILABLE = False
49
+
50
+
51
+ class SwiGLUFFNFused(SwiGLU):
52
+ def __init__(
53
+ self,
54
+ in_features: int,
55
+ hidden_features: Optional[int] = None,
56
+ out_features: Optional[int] = None,
57
+ act_layer: Callable[..., nn.Module] = None,
58
+ drop: float = 0.0,
59
+ bias: bool = True,
60
+ ) -> None:
61
+ out_features = out_features or in_features
62
+ hidden_features = hidden_features or in_features
63
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
64
+ super().__init__(
65
+ in_features=in_features,
66
+ hidden_features=hidden_features,
67
+ out_features=out_features,
68
+ bias=bias,
69
+ )
models/dinov2/vision_transformer.py ADDED
@@ -0,0 +1,454 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import math
12
+ from functools import partial
13
+ from typing import Callable, Sequence, Tuple, Union
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .layers import MemEffAttention, Mlp
21
+ from .layers import NestedTensorBlock as Block
22
+ from .layers import PatchEmbed, SwiGLUFFNFused
23
+
24
+ logger = logging.getLogger('dinov2')
25
+
26
+
27
+ def named_apply(
28
+ fn: Callable,
29
+ module: nn.Module,
30
+ name='',
31
+ depth_first=True,
32
+ include_root=False,
33
+ ) -> nn.Module:
34
+ if not depth_first and include_root:
35
+ fn(module=module, name=name)
36
+ for child_name, child_module in module.named_children():
37
+ child_name = '.'.join((name, child_name)) if name else child_name
38
+ named_apply(
39
+ fn=fn,
40
+ module=child_module,
41
+ name=child_name,
42
+ depth_first=depth_first,
43
+ include_root=True,
44
+ )
45
+ if depth_first and include_root:
46
+ fn(module=module, name=name)
47
+ return module
48
+
49
+
50
+ class BlockChunk(nn.ModuleList):
51
+ def forward(self, x):
52
+ for b in self:
53
+ x = b(x)
54
+ return x
55
+
56
+
57
+ class DinoVisionTransformer(nn.Module):
58
+ def __init__(
59
+ self,
60
+ img_size=224,
61
+ patch_size=16,
62
+ in_chans=3,
63
+ embed_dim=768,
64
+ depth=12,
65
+ num_heads=12,
66
+ mlp_ratio=4.0,
67
+ qkv_bias=True,
68
+ ffn_bias=True,
69
+ proj_bias=True,
70
+ drop_path_rate=0.0,
71
+ drop_path_uniform=False,
72
+ init_values=None, # for layerscale: None or 0 => no layerscale
73
+ embed_layer=PatchEmbed,
74
+ act_layer=nn.GELU,
75
+ block_fn=Block,
76
+ ffn_layer='mlp',
77
+ block_chunks=1,
78
+ num_register_tokens=0,
79
+ interpolate_antialias=False,
80
+ interpolate_offset=0.1,
81
+ ):
82
+ """
83
+ Args:
84
+ img_size (int, tuple): input image size
85
+ patch_size (int, tuple): patch size
86
+ in_chans (int): number of input channels
87
+ embed_dim (int): embedding dimension
88
+ depth (int): depth of transformer
89
+ num_heads (int): number of attention heads
90
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
91
+ qkv_bias (bool): enable bias for qkv if True
92
+ proj_bias (bool): enable bias for proj in attn if True
93
+ ffn_bias (bool): enable bias for ffn if True
94
+ drop_path_rate (float): stochastic depth rate
95
+ drop_path_uniform (bool): apply uniform drop rate across blocks
96
+ weight_init (str): weight init scheme
97
+ init_values (float): layer-scale init values
98
+ embed_layer (nn.Module): patch embedding layer
99
+ act_layer (nn.Module): MLP activation layer
100
+ block_fn (nn.Module): transformer block class
101
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
102
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
103
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
104
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
105
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
106
+ """
107
+ super().__init__()
108
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
109
+
110
+ self.num_features = (
111
+ self.embed_dim
112
+ ) = embed_dim # num_features for consistency with other models
113
+ self.num_tokens = 1
114
+ self.n_blocks = depth
115
+ self.num_heads = num_heads
116
+ self.patch_size = patch_size
117
+ self.num_register_tokens = num_register_tokens
118
+ self.interpolate_antialias = interpolate_antialias
119
+ self.interpolate_offset = interpolate_offset
120
+
121
+ self.patch_embed = embed_layer(
122
+ img_size=img_size,
123
+ patch_size=patch_size,
124
+ in_chans=in_chans,
125
+ embed_dim=embed_dim,
126
+ )
127
+ num_patches = self.patch_embed.num_patches
128
+
129
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
130
+ self.pos_embed = nn.Parameter(
131
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
132
+ )
133
+ assert num_register_tokens >= 0
134
+ self.register_tokens = (
135
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
136
+ if num_register_tokens
137
+ else None
138
+ )
139
+
140
+ if drop_path_uniform is True:
141
+ dpr = [drop_path_rate] * depth
142
+ else:
143
+ dpr = [
144
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
145
+ ] # stochastic depth decay rule
146
+
147
+ if ffn_layer == 'mlp':
148
+ logger.info('using MLP layer as FFN')
149
+ ffn_layer = Mlp
150
+ elif ffn_layer == 'swiglufused' or ffn_layer == 'swiglu':
151
+ logger.info('using SwiGLU layer as FFN')
152
+ ffn_layer = SwiGLUFFNFused
153
+ elif ffn_layer == 'identity':
154
+ logger.info('using Identity layer as FFN')
155
+
156
+ def f(*args, **kwargs):
157
+ return nn.Identity()
158
+
159
+ ffn_layer = f
160
+ else:
161
+ raise NotImplementedError
162
+
163
+ blocks_list = [
164
+ block_fn(
165
+ dim=embed_dim,
166
+ num_heads=num_heads,
167
+ mlp_ratio=mlp_ratio,
168
+ qkv_bias=qkv_bias,
169
+ proj_bias=proj_bias,
170
+ ffn_bias=ffn_bias,
171
+ drop_path=dpr[i],
172
+ norm_layer=norm_layer,
173
+ act_layer=act_layer,
174
+ ffn_layer=ffn_layer,
175
+ init_values=init_values,
176
+ )
177
+ for i in range(depth)
178
+ ]
179
+ if block_chunks > 0:
180
+ self.chunked_blocks = True
181
+ chunked_blocks = []
182
+ chunksize = depth // block_chunks
183
+ for i in range(0, depth, chunksize):
184
+ # this is to keep the block index consistent if we chunk the block list
185
+ chunked_blocks.append(
186
+ [nn.Identity()] * i + blocks_list[i : i + chunksize]
187
+ )
188
+ self.blocks = nn.ModuleList(
189
+ [BlockChunk(p) for p in chunked_blocks]
190
+ )
191
+ else:
192
+ self.chunked_blocks = False
193
+ self.blocks = nn.ModuleList(blocks_list)
194
+
195
+ self.norm = norm_layer(embed_dim)
196
+ self.head = nn.Identity()
197
+
198
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
199
+
200
+ self.init_weights()
201
+
202
+ def init_weights(self):
203
+ trunc_normal_(self.pos_embed, std=0.02)
204
+ nn.init.normal_(self.cls_token, std=1e-6)
205
+ if self.register_tokens is not None:
206
+ nn.init.normal_(self.register_tokens, std=1e-6)
207
+ named_apply(init_weights_vit_timm, self)
208
+
209
+ def interpolate_pos_encoding(self, x, w, h):
210
+ previous_dtype = x.dtype
211
+ npatch = x.shape[1] - 1
212
+ N = self.pos_embed.shape[1] - 1
213
+ if npatch == N and w == h:
214
+ return self.pos_embed
215
+ pos_embed = self.pos_embed.float()
216
+ class_pos_embed = pos_embed[:, 0]
217
+ patch_pos_embed = pos_embed[:, 1:]
218
+ dim = x.shape[-1]
219
+ w0 = w // self.patch_size
220
+ h0 = h // self.patch_size
221
+ M = int(
222
+ math.sqrt(N)
223
+ ) # Recover the number of patches in each dimension
224
+ assert N == M * M
225
+ kwargs = {}
226
+ if self.interpolate_offset:
227
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
228
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
229
+ sx = float(w0 + self.interpolate_offset) / M
230
+ sy = float(h0 + self.interpolate_offset) / M
231
+ kwargs['scale_factor'] = (sx, sy)
232
+ else:
233
+ # Simply specify an output size instead of a scale factor
234
+ kwargs['size'] = (w0, h0)
235
+ patch_pos_embed = nn.functional.interpolate(
236
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
237
+ mode='bicubic',
238
+ antialias=self.interpolate_antialias,
239
+ **kwargs,
240
+ )
241
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
242
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
243
+ return torch.cat(
244
+ (class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1
245
+ ).to(previous_dtype)
246
+
247
+ def prepare_tokens_with_masks(self, x, masks=None):
248
+ B, nc, w, h = x.shape
249
+ x = self.patch_embed(x)
250
+ if masks is not None:
251
+ x = torch.where(
252
+ masks.unsqueeze(-1),
253
+ self.mask_token.to(x.dtype).unsqueeze(0),
254
+ x,
255
+ )
256
+
257
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
258
+ x = x + self.interpolate_pos_encoding(x, w, h)
259
+
260
+ if self.register_tokens is not None:
261
+ x = torch.cat(
262
+ (
263
+ x[:, :1],
264
+ self.register_tokens.expand(x.shape[0], -1, -1),
265
+ x[:, 1:],
266
+ ),
267
+ dim=1,
268
+ )
269
+
270
+ return x
271
+
272
+ def forward_features_list(self, x_list, masks_list):
273
+ x = [
274
+ self.prepare_tokens_with_masks(x, masks)
275
+ for x, masks in zip(x_list, masks_list)
276
+ ]
277
+ for blk in self.blocks:
278
+ x = blk(x)
279
+
280
+ all_x = x
281
+ output = []
282
+ for x, masks in zip(all_x, masks_list):
283
+ x_norm = self.norm(x)
284
+ output.append(
285
+ {
286
+ 'x_norm_clstoken': x_norm[:, 0],
287
+ 'x_norm_regtokens': x_norm[
288
+ :, 1 : self.num_register_tokens + 1
289
+ ],
290
+ 'x_norm_patchtokens': x_norm[
291
+ :, self.num_register_tokens + 1 :
292
+ ],
293
+ 'x_prenorm': x,
294
+ 'masks': masks,
295
+ }
296
+ )
297
+ return output
298
+
299
+ def forward_features(self, x, masks=None):
300
+ if isinstance(x, list):
301
+ return self.forward_features_list(x, masks)
302
+
303
+ x = self.prepare_tokens_with_masks(x, masks)
304
+
305
+ for blk in self.blocks:
306
+ x = blk(x)
307
+
308
+ x_norm = self.norm(x)
309
+ return {
310
+ 'x_norm_clstoken': x_norm[:, 0],
311
+ 'x_norm_regtokens': x_norm[:, 1 : self.num_register_tokens + 1],
312
+ 'x_norm_patchtokens': x_norm[:, self.num_register_tokens + 1 :],
313
+ 'x_prenorm': x,
314
+ 'masks': masks,
315
+ }
316
+
317
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
318
+ x = self.prepare_tokens_with_masks(x)
319
+ # If n is an int, take the n last blocks. If it's a list, take them
320
+ output, total_block_len = [], len(self.blocks)
321
+ blocks_to_take = (
322
+ range(total_block_len - n, total_block_len)
323
+ if isinstance(n, int)
324
+ else n
325
+ )
326
+ for i, blk in enumerate(self.blocks):
327
+ x = blk(x)
328
+ if i in blocks_to_take:
329
+ output.append(x)
330
+ assert len(output) == len(
331
+ blocks_to_take
332
+ ), f'only {len(output)} / {len(blocks_to_take)} blocks found'
333
+ return output
334
+
335
+ def _get_intermediate_layers_chunked(self, x, n=1):
336
+ x = self.prepare_tokens_with_masks(x)
337
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
338
+ # If n is an int, take the n last blocks. If it's a list, take them
339
+ blocks_to_take = (
340
+ range(total_block_len - n, total_block_len)
341
+ if isinstance(n, int)
342
+ else n
343
+ )
344
+ for block_chunk in self.blocks:
345
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
346
+ x = blk(x)
347
+ if i in blocks_to_take:
348
+ output.append(x)
349
+ i += 1
350
+ assert len(output) == len(
351
+ blocks_to_take
352
+ ), f'only {len(output)} / {len(blocks_to_take)} blocks found'
353
+ return output
354
+
355
+ def get_intermediate_layers(
356
+ self,
357
+ x: torch.Tensor,
358
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
359
+ reshape: bool = False,
360
+ return_class_token: bool = False,
361
+ norm=True,
362
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
363
+ if self.chunked_blocks:
364
+ outputs = self._get_intermediate_layers_chunked(x, n)
365
+ else:
366
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
367
+ if norm:
368
+ outputs = [self.norm(out) for out in outputs]
369
+ class_tokens = [out[:, 0] for out in outputs]
370
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
371
+ if reshape:
372
+ B, _, w, h = x.shape
373
+ outputs = [
374
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
375
+ .permute(0, 3, 1, 2)
376
+ .contiguous()
377
+ for out in outputs
378
+ ]
379
+ if return_class_token:
380
+ return tuple(zip(outputs, class_tokens))
381
+ return tuple(outputs)
382
+
383
+ def forward(self, *args, is_training=False, **kwargs):
384
+ ret = self.forward_features(*args, **kwargs)
385
+ if is_training:
386
+ return ret
387
+ else:
388
+ return self.head(ret['x_norm_clstoken'])
389
+
390
+
391
+ def init_weights_vit_timm(module: nn.Module, name: str = ''):
392
+ """ViT weight initialization, original timm impl (for reproducibility)"""
393
+ if isinstance(module, nn.Linear):
394
+ trunc_normal_(module.weight, std=0.02)
395
+ if module.bias is not None:
396
+ nn.init.zeros_(module.bias)
397
+
398
+
399
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
400
+ model = DinoVisionTransformer(
401
+ patch_size=patch_size,
402
+ embed_dim=384,
403
+ depth=12,
404
+ num_heads=6,
405
+ mlp_ratio=4,
406
+ block_fn=partial(Block, attn_class=MemEffAttention),
407
+ num_register_tokens=num_register_tokens,
408
+ **kwargs,
409
+ )
410
+ return model
411
+
412
+
413
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
414
+ model = DinoVisionTransformer(
415
+ patch_size=patch_size,
416
+ embed_dim=768,
417
+ depth=12,
418
+ num_heads=12,
419
+ mlp_ratio=4,
420
+ block_fn=partial(Block, attn_class=MemEffAttention),
421
+ num_register_tokens=num_register_tokens,
422
+ **kwargs,
423
+ )
424
+ return model
425
+
426
+
427
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
428
+ model = DinoVisionTransformer(
429
+ patch_size=patch_size,
430
+ embed_dim=1024,
431
+ depth=24,
432
+ num_heads=16,
433
+ mlp_ratio=4,
434
+ block_fn=partial(Block, attn_class=MemEffAttention),
435
+ num_register_tokens=num_register_tokens,
436
+ **kwargs,
437
+ )
438
+ return model
439
+
440
+
441
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
442
+ """Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per
443
+ head 64."""
444
+ model = DinoVisionTransformer(
445
+ patch_size=patch_size,
446
+ embed_dim=1536,
447
+ depth=40,
448
+ num_heads=24,
449
+ mlp_ratio=4,
450
+ block_fn=partial(Block, attn_class=MemEffAttention),
451
+ num_register_tokens=num_register_tokens,
452
+ **kwargs,
453
+ )
454
+ return model
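As a usage illustration, the sketch below instantiates the small variant defined above and pulls reshaped patch tokens from the last block; the import path and the 224x224 input are assumptions for the example, not part of the file itself.

```python
import torch

# Assumed import path; adjust to wherever models/dinov2/vision_transformer.py lives.
from models.dinov2.vision_transformer import vit_small

model = vit_small(patch_size=14, block_chunks=0).eval()

# A 224x224 input gives a 16x16 grid of 14x14 patches.
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    (patches, cls_token), = model.get_intermediate_layers(
        x, n=1, reshape=True, return_class_token=True
    )
print(patches.shape, cls_token.shape)   # (1, 384, 16, 16) and (1, 384)
```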
neighbor_loss.py ADDED
@@ -0,0 +1,137 @@
1
+ import torch
2
+ from torch.nn.functional import normalize
3
+
4
+
5
+ def check_anomaly_theoretical(
6
+ x,
7
+ H,
8
+ W,
9
+ anomaly_dir=None,
10
+ temperature=0.1,
11
+ mask_thr=0.001,
12
+ kernel=3,
13
+ ):
14
+ x_token = x[:, 1:]
15
+ B = x.shape[0]
16
+ assert B == 1
17
+ x_token = x_token.reshape(H, W, -1).contiguous()
18
+
19
+ with torch.no_grad():
20
+ feature = normalize(x_token, dim=-1)
21
+ direction = normalize(anomaly_dir, dim=-1)
22
+
23
+ logits = -(feature * direction).sum(dim=-1).abs()
24
+ prob = torch.exp(logits / temperature)
25
+
26
+ assert kernel in (3, 5)
27
+ pad = kernel // 2
28
+
29
+ w = prob.unfold(0, kernel, 1).unfold(1, kernel, 1)
30
+ w = w / w.sum(dim=(-1, -2), keepdims=True)
31
+
32
+ if kernel == 3:
33
+ gaussian = (
34
+ torch.FloatTensor(
35
+ [
36
+ 1 / 16,
37
+ 1 / 8,
38
+ 1 / 16,
39
+ 1 / 8,
40
+ 1 / 4,
41
+ 1 / 8,
42
+ 1 / 16,
43
+ 1 / 8,
44
+ 1 / 16,
45
+ ]
46
+ )
47
+ .to(w.device)
48
+ .reshape(1, 1, 3, 3)
49
+ )
50
+ elif kernel == 5:
51
+ gaussian = (
52
+ torch.tensor(
53
+ [
54
+ [1, 4, 7, 4, 1],
55
+ [4, 16, 26, 16, 4],
56
+ [7, 26, 41, 26, 7],
57
+ [4, 16, 26, 16, 4],
58
+ [1, 4, 7, 4, 1],
59
+ ]
60
+ )
61
+ .float()
62
+ .to(w.device)
63
+ / 273
64
+ )
65
+
66
+ w2 = w * gaussian
67
+
68
+ w2 = w2 / w2.sum(dim=(-1, -2), keepdims=True)
69
+
70
+ T = x_token.unfold(0, kernel, 1).unfold(1, kernel, 1)
71
+ T = (T * w2[:, :, None].to(T.device)).sum(dim=(-1, -2))
72
+
73
+ mask_full = logits < logits.mean() - mask_thr * logits.std()
74
+ mask_full[:pad, :] = False
75
+ mask_full[:, :pad] = False
76
+ mask_full[-pad:, :] = False
77
+ mask_full[:, -pad:] = False
78
+ index_tensor = torch.nonzero(mask_full.flatten()).flatten()
79
+ if len(index_tensor) == 0:
80
+ return None
81
+ rows = index_tensor // W
82
+ cols = index_tensor % W
83
+
84
+ alpha = x_token[pad:-pad, pad:-pad].norm(dim=-1).mean()
85
+
86
+ loss_neighbor = (
87
+ (x_token[rows, cols] - T[rows - pad, cols - pad]).norm(dim=-1)
88
+ ).mean() / alpha
89
+
90
+ return loss_neighbor, rows, cols, T, alpha, mask_full, x_token
91
+
92
+
93
+ def get_neighbor_loss(
94
+ model,
95
+ x,
96
+ skip_less_than=1,
97
+ mask_thr=0.001,
98
+ kernel=3,
99
+ ):
100
+ H = x.shape[2]
101
+ W = x.shape[3]
102
+ x = model.prepare_tokens_with_masks(x)
103
+
104
+ for i, blk in enumerate(model.blocks):
105
+ x = blk(x)
106
+ assert len(model.singular_defects) > 0
107
+ result = check_anomaly_theoretical(
108
+ x,
109
+ H // model.patch_size,
110
+ W // model.patch_size,
111
+ model.singular_defects[i],
112
+ mask_thr=mask_thr,
113
+ kernel=kernel,
114
+ )
115
+ if result is not None:
116
+ (
117
+ loss_neighbor,
118
+ rows,
119
+ cols,
120
+ T,
121
+ alpha,
122
+ mask_angle,
123
+ x_token,
124
+ ) = result
125
+ if len(rows) >= skip_less_than:
126
+ assert not torch.isnan(loss_neighbor).any()
127
+ return (
128
+ i,
129
+ loss_neighbor,
130
+ rows,
131
+ cols,
132
+ T,
133
+ alpha,
134
+ mask_angle,
135
+ x_token,
136
+ )
137
+ return None
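A minimal sketch of how `get_neighbor_loss` could be called on a prepared model; it assumes a CUDA device and a model loaded through `load_model` (which attaches `model.singular_defects`), and the random 518x518 input is only a stand-in for a real, patch-aligned image batch of shape (1, 3, H, W).

```python
import torch
from sinder import get_neighbor_loss, load_model

model = load_model('dinov2_vitg14')           # also attaches model.singular_defects
image = torch.randn(1, 3, 518, 518).cuda()    # one image, sides divisible by 14

result = get_neighbor_loss(model, image, skip_less_than=1, kernel=3)
if result is None:
    print('no defective patch tokens detected for this input')
else:
    block_idx, loss_neighbor, rows, cols, T, alpha, mask, x_token = result
    print(f'block {block_idx}: {len(rows)} flagged tokens, '
          f'neighbor loss {loss_neighbor.item():.4f}')
```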
pyproject.toml ADDED
@@ -0,0 +1,17 @@
1
+ [build-system]
2
+ requires = ["setuptools>=68", "setuptools_scm[toml]>=8"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sinder"
7
+ requires-python = ">=3.8"
8
+ dynamic = ["version"]
9
+ dependencies = [
10
+ # Add runtime dependencies here
11
+ ]
12
+
13
+ # Enables the usage of setuptools_scm
14
+ [tool.setuptools_scm]
15
+
16
+ [tool.setuptools]
17
+ py-modules = []
repair.py ADDED
@@ -0,0 +1,147 @@
1
+ import torch
2
+
3
+
4
+ class SVDLinearAddition(torch.nn.Module):
5
+ def __init__(self, linear, *args, **kwargs) -> None:
6
+ super().__init__(*args, **kwargs)
7
+ w, b = linear.weight, linear.bias
8
+
9
+ self.bias = b
10
+
11
+ with torch.no_grad():
12
+ U, S, Vt = torch.linalg.svd(w, full_matrices=False)
13
+ self.U = torch.nn.Parameter(U)
14
+ self.S = torch.nn.Parameter(S)
15
+ self.Vt = torch.nn.Parameter(Vt)
16
+ self.epsilon = torch.nn.Parameter(torch.zeros_like(S))
17
+
18
+ def forward(self, x):
19
+ w = self.U @ torch.diag(self.S + self.epsilon) @ self.Vt
20
+ x = torch.nn.functional.linear(x, w, self.bias)
21
+ return x
22
+
23
+
24
+ class SVDLinearAdditionQKV(torch.nn.Module):
25
+ def __init__(self, linear, *args, **kwargs) -> None:
26
+ super().__init__(*args, **kwargs)
27
+ w, b = linear.weight, linear.bias
28
+
29
+ self.bias = b
30
+
31
+ with torch.no_grad():
32
+ c = w.shape[0] // 3
33
+ q = w[:c]
34
+ k = w[c : 2 * c]
35
+ v = w[2 * c :]
36
+ U_q, S_q, Vt_q = torch.linalg.svd(q, full_matrices=False)
37
+ self.U_q = torch.nn.Parameter(U_q)
38
+ self.S_q = torch.nn.Parameter(S_q)
39
+ self.Vt_q = torch.nn.Parameter(Vt_q)
40
+ self.epsilon_q = torch.nn.Parameter(torch.zeros_like(S_q))
41
+ U_k, S_k, Vt_k = torch.linalg.svd(k, full_matrices=False)
42
+ self.U_k = torch.nn.Parameter(U_k)
43
+ self.S_k = torch.nn.Parameter(S_k)
44
+ self.Vt_k = torch.nn.Parameter(Vt_k)
45
+ self.epsilon_k = torch.nn.Parameter(torch.zeros_like(S_k))
46
+ U_v, S_v, Vt_v = torch.linalg.svd(v, full_matrices=False)
47
+ self.U_v = torch.nn.Parameter(U_v)
48
+ self.S_v = torch.nn.Parameter(S_v)
49
+ self.Vt_v = torch.nn.Parameter(Vt_v)
50
+ self.epsilon_v = torch.nn.Parameter(torch.zeros_like(S_v))
51
+
52
+ def forward(self, x):
53
+ w = torch.concatenate(
54
+ (
55
+ self.U_q @ torch.diag(self.S_q + self.epsilon_q) @ self.Vt_q,
56
+ self.U_k @ torch.diag(self.S_k + self.epsilon_k) @ self.Vt_k,
57
+ self.U_v @ torch.diag(self.S_v + self.epsilon_v) @ self.Vt_v,
58
+ )
59
+ )
60
+ x = torch.nn.functional.linear(x, w, self.bias)
61
+ return x
62
+
63
+
64
+ def replace_linear_addition_noqk(module, name):
65
+ for attr_str in dir(module):
66
+ target_attr = getattr(module, attr_str)
67
+ if type(target_attr) == torch.nn.Linear:
68
+ if 'qkv' in attr_str:
69
+ print('replaced: ', name, attr_str)
70
+ svd_linear_qkv = SVDLinearAdditionQKV(target_attr)
71
+ svd_linear_qkv.U_q.requires_grad = False
72
+ svd_linear_qkv.U_k.requires_grad = False
73
+ svd_linear_qkv.U_v.requires_grad = False
74
+ svd_linear_qkv.S_q.requires_grad = False
75
+ svd_linear_qkv.S_k.requires_grad = False
76
+ svd_linear_qkv.S_v.requires_grad = False
77
+ svd_linear_qkv.Vt_q.requires_grad = False
78
+ svd_linear_qkv.Vt_k.requires_grad = False
79
+ svd_linear_qkv.Vt_v.requires_grad = False
80
+ svd_linear_qkv.epsilon_q.requires_grad = False
81
+ svd_linear_qkv.epsilon_k.requires_grad = False
82
+ svd_linear_qkv.epsilon_v.requires_grad = True
83
+ svd_linear_qkv.bias.requires_grad = False
84
+ setattr(module, attr_str, svd_linear_qkv)
85
+ else:
86
+ print('replaced: ', name, attr_str)
87
+ svd_linear = SVDLinearAddition(target_attr)
88
+ svd_linear.U.requires_grad = False
89
+ svd_linear.S.requires_grad = False
90
+ svd_linear.Vt.requires_grad = False
91
+ svd_linear.bias.requires_grad = False
92
+ svd_linear.epsilon.requires_grad = True
93
+ setattr(module, attr_str, svd_linear)
94
+
95
+ for name, immediate_child_module in module.named_children():
96
+ replace_linear_addition_noqk(immediate_child_module, name)
97
+
98
+
99
+ def replace_back(module, name):
100
+ for attr_str in dir(module):
101
+ target_attr = getattr(module, attr_str)
102
+
103
+ if type(target_attr) == SVDLinearAddition:
104
+ print('replaced back: ', name, attr_str)
105
+ with torch.no_grad():
106
+ linear = torch.nn.Linear(
107
+ target_attr.Vt.shape[1],
108
+ target_attr.U.shape[0],
109
+ device=target_attr.U.device,
110
+ )
111
+ linear.weight.add_(
112
+ target_attr.U
113
+ @ torch.diag(target_attr.S + target_attr.epsilon)
114
+ @ target_attr.Vt
115
+ - linear.weight
116
+ )
117
+ linear.bias.add_(target_attr.bias - linear.bias)
118
+
119
+ setattr(module, attr_str, linear)
120
+
121
+ elif type(target_attr) == SVDLinearAdditionQKV:
122
+ print('replaced back: ', name, attr_str)
123
+ with torch.no_grad():
124
+ w = torch.concatenate(
125
+ (
126
+ target_attr.U_q
127
+ @ torch.diag(target_attr.S_q + target_attr.epsilon_q)
128
+ @ target_attr.Vt_q,
129
+ target_attr.U_k
130
+ @ torch.diag(target_attr.S_k + target_attr.epsilon_k)
131
+ @ target_attr.Vt_k,
132
+ target_attr.U_v
133
+ @ torch.diag(target_attr.S_v + target_attr.epsilon_v)
134
+ @ target_attr.Vt_v,
135
+ )
136
+ )
137
+ linear = torch.nn.Linear(
138
+ w.shape[1], w.shape[0], device=target_attr.U_q.device
139
+ )
140
+ linear.weight.add_(w - linear.weight)
141
+ linear.bias.add_(target_attr.bias - linear.bias)
142
+
143
+ setattr(module, attr_str, linear)
144
+
145
+ # Iterate through immediate child modules; the recursion is handled by this function itself, so there is no need to use named_modules().
146
+ for name, immediate_child_module in module.named_children():
147
+ replace_back(immediate_child_module, name)
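The intended replace / tune / restore cycle around these two helpers could look like the sketch below; the optimizer, learning rate, and output path are illustrative, and the actual repair objective (e.g. the neighbor loss) is elided.

```python
import torch
from sinder import load_model, replace_back, replace_linear_addition_noqk

model = load_model('dinov2_vitg14')

# Swap every nn.Linear for its SVD-parametrized counterpart; the SVD factors are
# frozen and only the additive epsilon terms are meant to be learned.
replace_linear_addition_noqk(model, 'model')
epsilons = [p for n, p in model.named_parameters() if 'epsilon' in n]
optimizer = torch.optim.Adam(epsilons, lr=1e-4)

# ... compute a repair loss here, call loss.backward() and optimizer.step() ...

# Fold the learned epsilon back into plain nn.Linear weights for deployment.
replace_back(model, 'model')
torch.save(model.state_dict(), 'sinder_repaired.pth')
```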
setup.cfg ADDED
@@ -0,0 +1,23 @@
1
+ [bdist_wheel]
2
+ universal=1
3
+
4
+ [yapf]
5
+ based_on_style = pep8
6
+ blank_line_before_nested_class_or_def = true
7
+ split_before_expression_after_opening_paren = true
8
+ allow_split_before_dict_value = False
9
+ split_penalty_import_names=0
10
+ SPLIT_PENALTY_AFTER_OPENING_BRACKET=0
11
+
12
+ [isort]
13
+ line_length = 79
14
+ extra_standard_library = pkg_resources,setuptools,logging,os,warnings,abc
15
+ known_first_party =
16
+ known_third_party = addict,cv2,matplotlib,numpy,onnx,packaging,pytest,pytorch_sphinx_theme,scipy,sphinx,torch,torchvision,yaml,yapf
17
+ no_lines_before = STDLIB,LOCALFOLDER
18
+ default_section = THIRDPARTY
19
+ profile = black
20
+
21
+ [flake8]
22
+ extend-ignore = E501,E203,E722,E266,E402,E251
23
+ exclude = .git,__pycache__
singular_defect.py ADDED
@@ -0,0 +1,149 @@
1
+ # compute singular defect directions
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def anomaly_dir_attn(
7
+ blk,
8
+ identity=False,
9
+ bias=False,
10
+ centered=False,
11
+ homogeneous=False,
12
+ ):
13
+ with torch.no_grad():
14
+ N = blk.ls1.gamma.shape[0]
15
+ dev = blk.ls1.gamma.device
16
+
17
+ A4 = torch.diag(blk.ls1.gamma)
18
+ A3 = blk.attn.proj.weight
19
+ B3 = blk.attn.proj.bias
20
+ A2 = blk.attn.qkv.weight.chunk(3, dim=0)[-1]
21
+ B2 = blk.attn.qkv.bias.chunk(3, dim=0)[-1]
22
+ A1 = torch.diag(blk.norm1.weight)
23
+ B1 = blk.norm1.bias
24
+ A0 = (torch.eye(N) - 1 / N * torch.ones(N, N)).to(dev)
25
+ A = A4 @ A3 @ A2 @ A1
26
+
27
+ if centered:
28
+ A = A @ A0
29
+ B = A4 @ (A3 @ (A2 @ B1)) + A4 @ (A3 @ B2) + A4 @ B3
30
+
31
+ if bias:
32
+ A = torch.cat((A, B[:, None]), dim=1)
33
+ if homogeneous:
34
+ onehot = torch.cat(
35
+ (torch.zeros_like(B), torch.ones(1).to(dev))
36
+ )
37
+ A = torch.cat((A, onehot[None]), dim=0)
38
+
39
+ if identity:
40
+ iden = torch.eye(N).to(dev)
41
+ A[:N, :N] += iden
42
+ u, _, _ = torch.linalg.svd(A)
43
+
44
+ return u[:N, 0], A, B
45
+
46
+
47
+ def w12(blk, x):
48
+ with torch.no_grad():
49
+ x1, x2 = blk.mlp.w12(x).chunk(2, dim=-1)
50
+ return F.silu(x1) * x2
51
+
52
+
53
+ def anomaly_dir_mlp_ls(
54
+ blk,
55
+ identity=False,
56
+ bias=False,
57
+ centered=False,
58
+ homogeneous=False,
59
+ bias_ls=False,
60
+ ):
61
+ with torch.no_grad():
62
+ N = blk.ls2.gamma.shape[0]
63
+ M = blk.mlp.w3.weight.shape[1]
64
+ dev = blk.ls2.gamma.device
65
+
66
+ A4 = torch.diag(blk.ls2.gamma)
67
+ A3 = blk.mlp.w3.weight
68
+ B3 = blk.mlp.w3.bias
69
+
70
+ X = torch.randn(100000, N, device=dev)
71
+ Y = w12(blk, X)
72
+ if bias_ls:
73
+ X_one = torch.cat((X, torch.ones(100000, 1).to(dev)), dim=1)
74
+ else:
75
+ X_one = X
76
+ sol = torch.linalg.lstsq(X_one, Y)
77
+ if bias_ls:
78
+ A2 = sol.solution.T[:, :-1]
79
+ B2 = sol.solution.T[:, -1]
80
+ else:
81
+ A2 = sol.solution.T
82
+ B2 = torch.zeros(M).to(dev)
83
+
84
+ A1 = torch.diag(blk.norm2.weight)
85
+ B1 = blk.norm2.bias
86
+ A0 = (torch.eye(N) - 1 / N * torch.ones(N, N)).to(dev)
87
+ A = A4 @ A3 @ A2 @ A1
88
+
89
+ if centered:
90
+ A = A @ A0
91
+ B = A4 @ (A3 @ (A2 @ B1)) + A4 @ (A3 @ B2) + A4 @ B3
92
+
93
+ if bias:
94
+ A = torch.cat((A, B[:, None]), dim=1)
95
+ if homogeneous:
96
+ onehot = torch.cat(
97
+ (torch.zeros_like(B), torch.ones(1).to(dev))
98
+ )
99
+ A = torch.cat((A, onehot[None]), dim=0)
100
+
101
+ if identity:
102
+ iden = torch.eye(N).to(dev)
103
+ A[:N, :N] += iden
104
+ u, s, vt = torch.linalg.svd(A)
105
+
106
+ return u[:N, 0], A, B
107
+
108
+
109
+ def anomaly_dir(blk, homogeneous=False):
110
+ _, A, b = anomaly_dir_attn(
111
+ blk,
112
+ identity=True,
113
+ bias=homogeneous,
114
+ centered=True,
115
+ homogeneous=homogeneous,
116
+ )
117
+ _, C, d = anomaly_dir_mlp_ls(
118
+ blk,
119
+ identity=True,
120
+ bias=homogeneous,
121
+ bias_ls=False,
122
+ centered=True,
123
+ homogeneous=homogeneous,
124
+ )
125
+
126
+ with torch.no_grad():
127
+ N = b.shape[0]
128
+ AA = C @ A
129
+ if homogeneous:
130
+ BB = 0
131
+ else:
132
+ BB = C @ b + d
133
+ u, _, _ = torch.linalg.svd(AA)
134
+
135
+ return u[:N, 0], AA, BB
136
+
137
+
138
+ def singular_defect_directions(model):
139
+ accumulative_anomalies = []
140
+ anomaly_dab = [anomaly_dir(blk) for blk in model.blocks]
141
+ anomaly_as = [dab[1] for dab in anomaly_dab]
142
+
143
+ with torch.no_grad():
144
+ aaa = torch.eye(anomaly_as[0].shape[0]).to(anomaly_as[0])
145
+ for a in anomaly_as:
146
+ aaa = a @ aaa
147
+ u, _, _ = torch.linalg.svd(aaa)
148
+ accumulative_anomalies.append(u[:, 0])
149
+ return accumulative_anomalies
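To make the accumulation above concrete, here is a toy, self-contained sketch of the same idea: the defect direction after k blocks is taken as the top left singular vector of the product of the per-block linear maps. The random matrices merely stand in for the A matrices that `anomaly_dir` derives from each block's weights.

```python
import torch

torch.manual_seed(0)
dim, depth = 8, 4
block_maps = [torch.randn(dim, dim) for _ in range(depth)]   # stand-ins for A_i

acc = torch.eye(dim)
directions = []
for A in block_maps:
    acc = A @ acc                        # accumulate A_k ... A_2 A_1
    u, _, _ = torch.linalg.svd(acc)
    directions.append(u[:, 0])           # dominant output direction so far

print(torch.stack(directions).shape)     # torch.Size([4, 8])
```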
utils.py ADDED
@@ -0,0 +1,70 @@
1
+ from pathlib import Path
2
+
3
+ import torch
4
+ from PIL import Image
5
+ from sklearn.decomposition import PCA
6
+
7
+ import sinder
8
+ from .singular_defect import singular_defect_directions
9
+
10
+
11
+ def pca_array(tokens, whiten=False):
12
+ h, w, c = tokens.shape
13
+ tokens = tokens.detach().cpu()
14
+
15
+ pca = PCA(n_components=3, whiten=whiten)
16
+ pca.fit(tokens.reshape(-1, c))
17
+ projected_tokens = pca.transform(tokens.reshape(-1, c))
18
+
19
+ t = torch.tensor(projected_tokens)
20
+ t_min = t.min(dim=0, keepdim=True).values
21
+ t_max = t.max(dim=0, keepdim=True).values
22
+ normalized_t = (t - t_min) / (t_max - t_min)
23
+
24
+ array = (normalized_t * 255).byte().numpy()
25
+ array = array.reshape(h, w, 3)
26
+
27
+ return Image.fromarray(array).resize((w * 7, h * 7), 0)
28
+
29
+
30
+ def get_tokens(model, image, blocks=1):
31
+ model.eval()
32
+ with torch.no_grad():
33
+ image_batch = image.unsqueeze(0).cuda()
34
+ image_batch = image_batch.cuda()
35
+ H = image_batch.shape[2]
36
+ W = image_batch.shape[3]
37
+ print(f'{W=} {H=}')
38
+ tokens = model.get_intermediate_layers(
39
+ image_batch, blocks, return_class_token=True, norm=False
40
+ )
41
+ tokens = [
42
+ (
43
+ t.reshape(
44
+ (H // model.patch_size, W // model.patch_size, t.size(-1))
45
+ ),
46
+ tc,
47
+ )
48
+ for t, tc in tokens
49
+ ]
50
+
51
+ return tokens
52
+
53
+
54
+ def load_model(model_name, checkpoint=None):
55
+ print(f'using {model_name} model')
56
+ model = torch.hub.load(
57
+ repo_or_dir=Path(sinder.__file__).parent.parent,
58
+ source='local',
59
+ model=model_name,
60
+ )
61
+ if checkpoint is not None:
62
+ states = torch.load(checkpoint, map_location='cpu')
63
+ model.load_state_dict(states, strict=False)
64
+ model = model.cuda()
65
+ model.eval()
66
+ model.interpolate_antialias = True
67
+ model.singular_defects = singular_defect_directions(model)
68
+ print(f'model loaded. patch size: {model.patch_size}')
69
+
70
+ return model
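Tying the helpers above together, a minimal sketch of the intended flow is shown below; the preprocessing is an illustrative stand-in for the repository's own transform in data.py, and the checkpoint argument is optional.

```python
import torchvision.transforms as transforms
from PIL import Image

from sinder import get_tokens, load_model, pca_array

model = load_model('dinov2_vitg14')   # or load_model('dinov2_vitg14', 'path/to/sinder.pth')

# Illustrative preprocessing; both sides must be multiples of the patch size (14).
preprocess = transforms.Compose([
    transforms.Resize((518, 518)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])
image = preprocess(Image.open('resources/example.jpg').convert('RGB'))

tokens = get_tokens(model, image)     # list of (patch tokens, class token) pairs
patch_tokens, class_token = tokens[-1]
pca_array(patch_tokens).save('example_pca.png')
```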
visualize.py ADDED
@@ -0,0 +1,82 @@
1
+ #!/usr/bin/env python
2
+ import argparse
3
+ import os
4
+ import sys
5
+ from pathlib import Path
6
+ import cv2
7
+
8
+ import numpy as np
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ from sinder import (
13
+ get_tokens,
14
+ load_model,
15
+ load_visual_data,
16
+ pca_array,
17
+ )
18
+
19
+ os.environ['XFORMERS_DISABLED'] = '1'
20
+
21
+
22
+ def parse_args():
23
+ parser = argparse.ArgumentParser(description='Visualize')
24
+ parser.add_argument(
25
+ 'imgs', nargs='+', type=str, help='path to image/images'
26
+ )
27
+ parser.add_argument(
28
+ '--model', type=str, default='dinov2_vitg14', help='model name'
29
+ )
30
+ parser.add_argument('--workdir', type=str, default='visualize')
31
+ parser.add_argument(
32
+ '--checkpoint',
33
+ type=str,
34
+ default=None,
35
+ help='Path to checkpoint. Default is None, which loads the official pretrained weights',
36
+ )
37
+ parser.add_argument(
38
+ '--visual_size',
39
+ type=int,
40
+ default=518,
41
+ help='short side size of input image',
42
+ )
43
+
44
+ args = parser.parse_args()
45
+ return args
46
+
47
+
48
+ def visualize(args, model, visual_dataset):
49
+ model.eval()
50
+
51
+ for d in tqdm(range(len(visual_dataset))):
52
+ visual_image = visual_dataset[d]
53
+ visual_tokens_all = get_tokens(model, visual_image)
54
+ visual_tokens, visual_tokens_cls = zip(*visual_tokens_all)
55
+ filename = Path(visual_dataset.files[d]).stem
56
+
57
+ t = visual_tokens[-1].detach().cpu()
58
+ h, w, c = t.shape
59
+ norm = ((t.norm(dim=-1) / t.norm(dim=-1).max()) * 255).byte().numpy()
60
+ norm_img = Image.fromarray(norm).resize((w * 14, h * 14), 0)
61
+ norm = cv2.applyColorMap(np.array(norm_img), cv2.COLORMAP_JET)
62
+ cv2.imwrite(args.folder / f'{filename}_norm.png', norm)
63
+
64
+ pca_img = pca_array(visual_tokens[-1])
65
+ pca_img.save(args.folder / f'{filename}_pca.png')
66
+
67
+
68
+ def main():
69
+ args = parse_args()
70
+
71
+ args.folder = Path(args.workdir).expanduser()
72
+ os.makedirs(args.folder, exist_ok=True)
73
+ print(args)
74
+ print(' '.join(sys.argv))
75
+
76
+ model = load_model(args.model, args.checkpoint)
77
+ visual_dataset = load_visual_data(args, model)
78
+ visualize(args, model, visual_dataset)
79
+
80
+
81
+ if __name__ == '__main__':
82
+ main()