Vishakaraj committed
Commit
3fad000
1 Parent(s): ad1d2d7

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. .github/workflows/lint.yaml +39 -0
  3. .gitignore +11 -0
  4. .vscode/launch.json +16 -0
  5. CODE_OF_CONDUCT.md +80 -0
  6. CONTRIBUTING.md +31 -0
  7. Dockerfile +11 -0
  8. LICENSE +203 -0
  9. MODEL_CARD.md +201 -0
  10. README.md +497 -8
  11. app.py +137 -0
  12. conda-extras.yaml +24 -0
  13. conda.yaml +22 -0
  14. demo.py +153 -0
  15. dinov2/__init__.py +6 -0
  16. dinov2/__pycache__/__init__.cpython-310.pyc +0 -0
  17. dinov2/configs/__init__.py +22 -0
  18. dinov2/configs/eval/vitb14_pretrain.yaml +6 -0
  19. dinov2/configs/eval/vitg14_pretrain.yaml +7 -0
  20. dinov2/configs/eval/vitl14_pretrain.yaml +6 -0
  21. dinov2/configs/eval/vits14_pretrain.yaml +6 -0
  22. dinov2/configs/ssl_default_config.yaml +115 -0
  23. dinov2/configs/train/vitg14.yaml +26 -0
  24. dinov2/configs/train/vitl14.yaml +26 -0
  25. dinov2/configs/train/vitl16_short.yaml +6 -0
  26. dinov2/data/__init__.py +10 -0
  27. dinov2/data/adapters.py +28 -0
  28. dinov2/data/augmentations.py +118 -0
  29. dinov2/data/collate.py +49 -0
  30. dinov2/data/datasets/__init__.py +7 -0
  31. dinov2/data/datasets/decoders.py +31 -0
  32. dinov2/data/datasets/extended.py +38 -0
  33. dinov2/data/datasets/image_net.py +290 -0
  34. dinov2/data/datasets/image_net_22k.py +302 -0
  35. dinov2/data/loaders.py +222 -0
  36. dinov2/data/masking.py +86 -0
  37. dinov2/data/samplers.py +229 -0
  38. dinov2/data/transforms.py +91 -0
  39. dinov2/distributed/__init__.py +270 -0
  40. dinov2/eval/__init__.py +4 -0
  41. dinov2/eval/__pycache__/__init__.cpython-310.pyc +0 -0
  42. dinov2/eval/depth/__init__.py +4 -0
  43. dinov2/eval/depth/models/__init__.py +10 -0
  44. dinov2/eval/depth/models/backbones/__init__.py +6 -0
  45. dinov2/eval/depth/models/backbones/vision_transformer.py +16 -0
  46. dinov2/eval/depth/models/builder.py +49 -0
  47. dinov2/eval/depth/models/decode_heads/__init__.py +7 -0
  48. dinov2/eval/depth/models/decode_heads/decode_head.py +225 -0
  49. dinov2/eval/depth/models/decode_heads/dpt_head.py +270 -0
  50. dinov2/eval/depth/models/decode_heads/linear_head.py +89 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ mmcv_full-1.5.0-cp310-cp310-linux_x86_64.whl filter=lfs diff=lfs merge=lfs -text
.github/workflows/lint.yaml ADDED
@@ -0,0 +1,39 @@
+ name: Lint
+
+ on:
+   push:
+     branches:
+       - main
+   pull_request:
+     branches:
+       - master
+       - 'gh/**'
+
+ jobs:
+   run-linters:
+     name: Run linters
+     runs-on: ubuntu-20.04
+
+     steps:
+       - name: Checkout repository
+         uses: actions/checkout@v3
+       - name: Set up Python
+         uses: actions/setup-python@v4
+         with:
+           python-version: 3.9
+           cache: 'pip'
+           cache-dependency-path: '**/requirements*.txt'
+       - name: Install Python (development) dependencies
+         run: |
+           pip install -r requirements-dev.txt
+       - name: Run flake8
+         run: |
+           flake8
+       - name: Run black
+         if: always()
+         run: |
+           black --check dinov2
+       - name: Run pylint
+         if: always()
+         run: |
+           pylint --exit-zero dinov2
.gitignore ADDED
@@ -0,0 +1,11 @@
+ build/
+ dist/
+ *.egg-info/
+ **/__pycache__/
+
+ **/.ipynb_checkpoints
+ **/.ipynb_checkpoints/**
+
+ *.swp
+
+ .vscode/
.vscode/launch.json ADDED
@@ -0,0 +1,16 @@
+ {
+     // Use IntelliSense to learn about possible attributes.
+     // Hover to view descriptions of existing attributes.
+     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+     "version": "0.2.0",
+     "configurations": [
+         {
+             "name": "Python: Current File",
+             "type": "python",
+             "request": "launch",
+             "program": "${file}",
+             "console": "integratedTerminal",
+             "justMyCode": false
+         }
+     ]
+ }
CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,80 @@
1
+ # Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ In the interest of fostering an open and welcoming environment, we as
6
+ contributors and maintainers pledge to make participation in our project and
7
+ our community a harassment-free experience for everyone, regardless of age, body
8
+ size, disability, ethnicity, sex characteristics, gender identity and expression,
9
+ level of experience, education, socio-economic status, nationality, personal
10
+ appearance, race, religion, or sexual identity and orientation.
11
+
12
+ ## Our Standards
13
+
14
+ Examples of behavior that contributes to creating a positive environment
15
+ include:
16
+
17
+ * Using welcoming and inclusive language
18
+ * Being respectful of differing viewpoints and experiences
19
+ * Gracefully accepting constructive criticism
20
+ * Focusing on what is best for the community
21
+ * Showing empathy towards other community members
22
+
23
+ Examples of unacceptable behavior by participants include:
24
+
25
+ * The use of sexualized language or imagery and unwelcome sexual attention or
26
+ advances
27
+ * Trolling, insulting/derogatory comments, and personal or political attacks
28
+ * Public or private harassment
29
+ * Publishing others' private information, such as a physical or electronic
30
+ address, without explicit permission
31
+ * Other conduct which could reasonably be considered inappropriate in a
32
+ professional setting
33
+
34
+ ## Our Responsibilities
35
+
36
+ Project maintainers are responsible for clarifying the standards of acceptable
37
+ behavior and are expected to take appropriate and fair corrective action in
38
+ response to any instances of unacceptable behavior.
39
+
40
+ Project maintainers have the right and responsibility to remove, edit, or
41
+ reject comments, commits, code, wiki edits, issues, and other contributions
42
+ that are not aligned to this Code of Conduct, or to ban temporarily or
43
+ permanently any contributor for other behaviors that they deem inappropriate,
44
+ threatening, offensive, or harmful.
45
+
46
+ ## Scope
47
+
48
+ This Code of Conduct applies within all project spaces, and it also applies when
49
+ an individual is representing the project or its community in public spaces.
50
+ Examples of representing a project or community include using an official
51
+ project e-mail address, posting via an official social media account, or acting
52
+ as an appointed representative at an online or offline event. Representation of
53
+ a project may be further defined and clarified by project maintainers.
54
+
55
+ This Code of Conduct also applies outside the project spaces when there is a
56
+ reasonable belief that an individual's behavior may have a negative impact on
57
+ the project or its community.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported by contacting the project team at <opensource-conduct@meta.com>. All
63
+ complaints will be reviewed and investigated and will result in a response that
64
+ is deemed necessary and appropriate to the circumstances. The project team is
65
+ obligated to maintain confidentiality with regard to the reporter of an incident.
66
+ Further details of specific enforcement policies may be posted separately.
67
+
68
+ Project maintainers who do not follow or enforce the Code of Conduct in good
69
+ faith may face temporary or permanent repercussions as determined by other
70
+ members of the project's leadership.
71
+
72
+ ## Attribution
73
+
74
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
75
+ available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
76
+
77
+ [homepage]: https://www.contributor-covenant.org
78
+
79
+ For answers to common questions about this code of conduct, see
80
+ https://www.contributor-covenant.org/faq
CONTRIBUTING.md ADDED
@@ -0,0 +1,31 @@
+ # Contributing to DINOv2
+ We want to make contributing to this project as easy and transparent as
+ possible.
+
+ ## Pull Requests
+ We actively welcome your pull requests.
+
+ 1. Fork the repo and create your branch from `main`.
+ 2. If you've added code that should be tested, add tests.
+ 3. If you've changed APIs, update the documentation.
+ 4. Ensure the test suite passes.
+ 5. Make sure your code lints.
+ 6. If you haven't already, complete the Contributor License Agreement ("CLA").
+
+ ## Contributor License Agreement ("CLA")
+ In order to accept your pull request, we need you to submit a CLA. You only need
+ to do this once to work on any of Meta's open source projects.
+
+ Complete your CLA here: <https://code.facebook.com/cla>
+
+ ## Issues
+ We use GitHub issues to track public bugs. Please ensure your description is
+ clear and has sufficient instructions to be able to reproduce the issue.
+
+ Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe
+ disclosure of security bugs. In those cases, please go through the process
+ outlined on that page and do not file a public issue.
+
+ ## License
+ By contributing to DINOv2, you agree that your contributions will be licensed
+ under the LICENSE file in the root directory of this source tree.
Dockerfile ADDED
@@ -0,0 +1,11 @@
+ FROM pytorch/pytorch:2.0.0-cuda11.7-cudnn8-devel
+
+ COPY . /dinov2
+ WORKDIR /dinov2
+
+ RUN pip install -r requirements.txt
+ RUN pip install -r requirements-extras.txt
+
+ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+
+
LICENSE ADDED
@@ -0,0 +1,203 @@
1
+
2
+
3
+ Apache License
4
+ Version 2.0, January 2004
5
+ http://www.apache.org/licenses/
6
+
7
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
8
+
9
+ 1. Definitions.
10
+
11
+ "License" shall mean the terms and conditions for use, reproduction,
12
+ and distribution as defined by Sections 1 through 9 of this document.
13
+
14
+ "Licensor" shall mean the copyright owner or entity authorized by
15
+ the copyright owner that is granting the License.
16
+
17
+ "Legal Entity" shall mean the union of the acting entity and all
18
+ other entities that control, are controlled by, or are under common
19
+ control with that entity. For the purposes of this definition,
20
+ "control" means (i) the power, direct or indirect, to cause the
21
+ direction or management of such entity, whether by contract or
22
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
23
+ outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ "You" (or "Your") shall mean an individual or Legal Entity
26
+ exercising permissions granted by this License.
27
+
28
+ "Source" form shall mean the preferred form for making modifications,
29
+ including but not limited to software source code, documentation
30
+ source, and configuration files.
31
+
32
+ "Object" form shall mean any form resulting from mechanical
33
+ transformation or translation of a Source form, including but
34
+ not limited to compiled object code, generated documentation,
35
+ and conversions to other media types.
36
+
37
+ "Work" shall mean the work of authorship, whether in Source or
38
+ Object form, made available under the License, as indicated by a
39
+ copyright notice that is included in or attached to the work
40
+ (an example is provided in the Appendix below).
41
+
42
+ "Derivative Works" shall mean any work, whether in Source or Object
43
+ form, that is based on (or derived from) the Work and for which the
44
+ editorial revisions, annotations, elaborations, or other modifications
45
+ represent, as a whole, an original work of authorship. For the purposes
46
+ of this License, Derivative Works shall not include works that remain
47
+ separable from, or merely link (or bind by name) to the interfaces of,
48
+ the Work and Derivative Works thereof.
49
+
50
+ "Contribution" shall mean any work of authorship, including
51
+ the original version of the Work and any modifications or additions
52
+ to that Work or Derivative Works thereof, that is intentionally
53
+ submitted to Licensor for inclusion in the Work by the copyright owner
54
+ or by an individual or Legal Entity authorized to submit on behalf of
55
+ the copyright owner. For the purposes of this definition, "submitted"
56
+ means any form of electronic, verbal, or written communication sent
57
+ to the Licensor or its representatives, including but not limited to
58
+ communication on electronic mailing lists, source code control systems,
59
+ and issue tracking systems that are managed by, or on behalf of, the
60
+ Licensor for the purpose of discussing and improving the Work, but
61
+ excluding communication that is conspicuously marked or otherwise
62
+ designated in writing by the copyright owner as "Not a Contribution."
63
+
64
+ "Contributor" shall mean Licensor and any individual or Legal Entity
65
+ on behalf of whom a Contribution has been received by Licensor and
66
+ subsequently incorporated within the Work.
67
+
68
+ 2. Grant of Copyright License. Subject to the terms and conditions of
69
+ this License, each Contributor hereby grants to You a perpetual,
70
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
71
+ copyright license to reproduce, prepare Derivative Works of,
72
+ publicly display, publicly perform, sublicense, and distribute the
73
+ Work and such Derivative Works in Source or Object form.
74
+
75
+ 3. Grant of Patent License. Subject to the terms and conditions of
76
+ this License, each Contributor hereby grants to You a perpetual,
77
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
78
+ (except as stated in this section) patent license to make, have made,
79
+ use, offer to sell, sell, import, and otherwise transfer the Work,
80
+ where such license applies only to those patent claims licensable
81
+ by such Contributor that are necessarily infringed by their
82
+ Contribution(s) alone or by combination of their Contribution(s)
83
+ with the Work to which such Contribution(s) was submitted. If You
84
+ institute patent litigation against any entity (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
86
+ or a Contribution incorporated within the Work constitutes direct
87
+ or contributory patent infringement, then any patent licenses
88
+ granted to You under this License for that Work shall terminate
89
+ as of the date such litigation is filed.
90
+
91
+ 4. Redistribution. You may reproduce and distribute copies of the
92
+ Work or Derivative Works thereof in any medium, with or without
93
+ modifications, and in Source or Object form, provided that You
94
+ meet the following conditions:
95
+
96
+ (a) You must give any other recipients of the Work or
97
+ Derivative Works a copy of this License; and
98
+
99
+ (b) You must cause any modified files to carry prominent notices
100
+ stating that You changed the files; and
101
+
102
+ (c) You must retain, in the Source form of any Derivative Works
103
+ that You distribute, all copyright, patent, trademark, and
104
+ attribution notices from the Source form of the Work,
105
+ excluding those notices that do not pertain to any part of
106
+ the Derivative Works; and
107
+
108
+ (d) If the Work includes a "NOTICE" text file as part of its
109
+ distribution, then any Derivative Works that You distribute must
110
+ include a readable copy of the attribution notices contained
111
+ within such NOTICE file, excluding those notices that do not
112
+ pertain to any part of the Derivative Works, in at least one
113
+ of the following places: within a NOTICE text file distributed
114
+ as part of the Derivative Works; within the Source form or
115
+ documentation, if provided along with the Derivative Works; or,
116
+ within a display generated by the Derivative Works, if and
117
+ wherever such third-party notices normally appear. The contents
118
+ of the NOTICE file are for informational purposes only and
119
+ do not modify the License. You may add Your own attribution
120
+ notices within Derivative Works that You distribute, alongside
121
+ or as an addendum to the NOTICE text from the Work, provided
122
+ that such additional attribution notices cannot be construed
123
+ as modifying the License.
124
+
125
+ You may add Your own copyright statement to Your modifications and
126
+ may provide additional or different license terms and conditions
127
+ for use, reproduction, or distribution of Your modifications, or
128
+ for any such Derivative Works as a whole, provided Your use,
129
+ reproduction, and distribution of the Work otherwise complies with
130
+ the conditions stated in this License.
131
+
132
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
133
+ any Contribution intentionally submitted for inclusion in the Work
134
+ by You to the Licensor shall be under the terms and conditions of
135
+ this License, without any additional terms or conditions.
136
+ Notwithstanding the above, nothing herein shall supersede or modify
137
+ the terms of any separate license agreement you may have executed
138
+ with Licensor regarding such Contributions.
139
+
140
+ 6. Trademarks. This License does not grant permission to use the trade
141
+ names, trademarks, service marks, or product names of the Licensor,
142
+ except as required for reasonable and customary use in describing the
143
+ origin of the Work and reproducing the content of the NOTICE file.
144
+
145
+ 7. Disclaimer of Warranty. Unless required by applicable law or
146
+ agreed to in writing, Licensor provides the Work (and each
147
+ Contributor provides its Contributions) on an "AS IS" BASIS,
148
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
149
+ implied, including, without limitation, any warranties or conditions
150
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
151
+ PARTICULAR PURPOSE. You are solely responsible for determining the
152
+ appropriateness of using or redistributing the Work and assume any
153
+ risks associated with Your exercise of permissions under this License.
154
+
155
+ 8. Limitation of Liability. In no event and under no legal theory,
156
+ whether in tort (including negligence), contract, or otherwise,
157
+ unless required by applicable law (such as deliberate and grossly
158
+ negligent acts) or agreed to in writing, shall any Contributor be
159
+ liable to You for damages, including any direct, indirect, special,
160
+ incidental, or consequential damages of any character arising as a
161
+ result of this License or out of the use or inability to use the
162
+ Work (including but not limited to damages for loss of goodwill,
163
+ work stoppage, computer failure or malfunction, or any and all
164
+ other commercial damages or losses), even if such Contributor
165
+ has been advised of the possibility of such damages.
166
+
167
+ 9. Accepting Warranty or Additional Liability. While redistributing
168
+ the Work or Derivative Works thereof, You may choose to offer,
169
+ and charge a fee for, acceptance of support, warranty, indemnity,
170
+ or other liability obligations and/or rights consistent with this
171
+ License. However, in accepting such obligations, You may act only
172
+ on Your own behalf and on Your sole responsibility, not on behalf
173
+ of any other Contributor, and only if You agree to indemnify,
174
+ defend, and hold each Contributor harmless for any liability
175
+ incurred by, or claims asserted against, such Contributor by reason
176
+ of your accepting any such warranty or additional liability.
177
+
178
+ END OF TERMS AND CONDITIONS
179
+
180
+ APPENDIX: How to apply the Apache License to your work.
181
+
182
+ To apply the Apache License to your work, attach the following
183
+ boilerplate notice, with the fields enclosed by brackets "[]"
184
+ replaced with your own identifying information. (Don't include
185
+ the brackets!) The text should be enclosed in the appropriate
186
+ comment syntax for the file format. We also recommend that a
187
+ file or class name and description of purpose be included on the
188
+ same "printed page" as the copyright notice for easier
189
+ identification within third-party archives.
190
+
191
+ Copyright [yyyy] [name of copyright owner]
192
+
193
+ Licensed under the Apache License, Version 2.0 (the "License");
194
+ you may not use this file except in compliance with the License.
195
+ You may obtain a copy of the License at
196
+
197
+ http://www.apache.org/licenses/LICENSE-2.0
198
+
199
+ Unless required by applicable law or agreed to in writing, software
200
+ distributed under the License is distributed on an "AS IS" BASIS,
201
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
202
+ See the License for the specific language governing permissions and
203
+ limitations under the License.
MODEL_CARD.md ADDED
@@ -0,0 +1,201 @@
1
+ # Model Card for DINOv2-S/B/L/g
2
+
3
+ These are Vision Transformer models trained following the method described in the paper:
4
+ "DINOv2: Learning Robust Visual Features without Supervision"
5
+
6
+ We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g.
7
+
8
+ ## Model Details
9
+ The model takes an image as input and returns a class token and patch tokens.
10
+
11
+ The embedding dimension is:
12
+ - 384 for ViT-S.
13
+ - 768 for ViT-B.
14
+ - 1024 for ViT-L.
15
+ - 1536 for ViT-g.
16
+
17
+ The models follow a Transformer architecture, with a patch size of 14.
18
+
19
+ For a 224x224 image, this results in 1 class token + 256 patch tokens.
20
+
21
+ The models can accept larger images provided the image shapes are multiples of the patch size (14).
22
+ If this condition is not met, the model crops the input down to the closest smaller multiple of the patch size.
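As a quick illustration of the token geometry described above, here is a minimal sketch (editorial example, not part of the committed model card; `crop_to_multiple` is a hypothetical stand-in for the model's internal cropping):

```python
import torch

PATCH_SIZE = 14  # stated patch size of all DINOv2 variants

def num_tokens(height: int, width: int) -> int:
    # one class token plus one token per 14x14 patch
    return 1 + (height // PATCH_SIZE) * (width // PATCH_SIZE)

def crop_to_multiple(x: torch.Tensor) -> torch.Tensor:
    # hypothetical helper mirroring the described behaviour for non-multiple shapes:
    # crop H and W down to the closest smaller multiple of the patch size
    _, _, h, w = x.shape
    return x[..., : h - h % PATCH_SIZE, : w - w % PATCH_SIZE]

print(num_tokens(224, 224))  # 257 = 1 class token + 16 * 16 patch tokens
print(crop_to_multiple(torch.zeros(1, 3, 500, 375)).shape)  # torch.Size([1, 3, 490, 364])
```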
23
+
24
+ ### Model Description
25
+
26
+ - **Developed by:** Meta AI
27
+ - **Model type:** Vision Transformer
28
+ - **License:** Apache License 2.0
29
+
30
+ - **Repository:** https://github.com/facebookresearch/dinov2
31
+ - **Paper:** https://arxiv.org/abs/2304.07193
32
+ - **Demo:** https://dinov2.metademolab.com/
33
+
34
+ ## Uses
35
+
36
+ The models are vision backbones providing multi-purpose features for downstream tasks.
37
+
38
+ ### Direct Use
39
+
40
+ The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results:
41
+ - on depth estimation, semantic segmentation, using linear layers.
42
+ - on image classification, using k-NN classifiers on the class token.
43
+ - on image classification, with logistic regression classifiers applied on the class token.
44
+ - on image classification, with a linear layer applied on the class token and the average of the patch tokens.
45
+ - on image retrieval using nearest neighbors.
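To make the setups in the list above concrete, a minimal feature-extraction sketch (illustrative only; the `forward_features` output keys follow the DINOv2 code base and should be treated as an assumption here):

```python
import torch

# Backbone loading matches the "How to Get Started" snippet further down.
backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14').eval()

img = torch.randn(1, 3, 224, 224)  # placeholder for a normalized input image
with torch.inference_mode():
    feats = backbone.forward_features(img)

cls_token = feats["x_norm_clstoken"]        # (1, 384): k-NN / logistic regression features
patch_tokens = feats["x_norm_patchtokens"]  # (1, 256, 384): features for dense tasks
# "class token + average of the patch tokens" variant for a linear classifier:
linear_features = torch.cat([cls_token, patch_tokens.mean(dim=1)], dim=-1)  # (1, 768)
```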
46
+
47
+ ### Downstream Use
48
+
49
+ It is technically possible to perform fine-tuning on the models, for small gains (we measured +2% on ImageNet-1k classification).
50
+ We recommend keeping this as a very last step and only when necessary, as the features already provide good performance out-of-the-box.
51
+
52
+ ## Bias, Risks, and Limitations
53
+
54
+ Despite improvements thanks to the training method not using annotations, we still observe significant biases in our models toward rich households from Western countries.
55
+
56
+ ### Recommendations
57
+
58
+ We expect fine-tuning will increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels.
59
+
60
+ ## How to Get Started with the Model
61
+
62
+ Use the code below to get started with the model.
63
+
64
+ ```python
65
+ import torch
66
+ dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
67
+ dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
68
+ dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
69
+ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
70
+ ```
71
+
72
+ ## Training Details
73
+
74
+ ### Training Data
75
+
76
+ - **Training data:** LVD-142M (see paper)
77
+ - **Training regime:** fp16 using PyTorch-FSDP mixed-precision.
78
+
79
+ ### Training Procedure
80
+
81
+ - **Training objective:**
82
+ - DINO self-distillation loss with multi-crop
83
+ - iBOT masked-image modeling loss
84
+ - KoLeo regularization on [CLS] tokens
85
+ - **Architectures:**
86
+ - ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN
87
+ - ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN
88
+ - ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN
89
+ - ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN
90
+ - **Distillation:**
91
+ - Distillation follows the standard DINOv2 pretraining procedure, except the teacher is a pretrained ViT-g, frozen.
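For reference, the KoLeo term listed above spreads the normalized [CLS] embeddings within a batch by penalizing small nearest-neighbor distances; a minimal sketch of the formula as described in the paper (my own illustration, not the implementation shipped in this commit):

```python
import torch
import torch.nn.functional as F

def koleo_loss(cls_tokens: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # L = -(1/n) * sum_i log( min_{j != i} ||x_i - x_j|| ) on L2-normalized embeddings
    x = F.normalize(cls_tokens, p=2, dim=-1)
    dists = torch.cdist(x, x)                  # (n, n) pairwise distances
    dists.fill_diagonal_(float("inf"))         # exclude self-distances
    nearest = dists.min(dim=1).values          # distance to each point's nearest neighbor
    return -torch.log(nearest + eps).mean()

loss = koleo_loss(torch.randn(16, 384))  # e.g. a batch of ViT-S [CLS] tokens
```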
92
+
93
+ ## Evaluation
94
+
95
+ We refer users to the associated paper for the evaluation protocols.
96
+
97
+ <table>
98
+ <tr>
99
+ <th>model</th>
100
+ <th colspan="3">ImageNet-1k</th>
101
+ <th>NYU-Depth v2</th>
102
+ <th>SUN-RGBD</th>
103
+ <th>ADE20k</th>
104
+ <th>iNaturalist 2018</th>
105
+ <th>Oxford-H</th>
106
+ </tr>
107
+ <tr>
108
+ <th rowspan="2">task</th>
109
+ <th>classif. (acc)</th>
110
+ <th>classif. (acc)</th>
111
+ <th>classif. V2 (acc)</th>
112
+ <th>depth (RMSE)</th>
113
+ <th>depth (RMSE)</th>
114
+ <th>segm. (mAP)</th>
115
+ <th>classif. (acc)</th>
116
+ <th>retrieval (mAP)</th>
117
+ </tr>
118
+ <tr>
119
+ <!-- <th>^</th> -->
120
+ <th>k-NN</th>
121
+ <th>linear</th>
122
+ <th>linear</th>
123
+ <th>linear<br />4 layers</th>
124
+ <th>NYU-D transfer</th>
125
+ <th>multiscale</th>
126
+ <th>linear</th>
127
+ <th>nearest neighbor</th>
128
+ </tr>
129
+ <tr>
130
+ <td>ViT-S/14</td>
131
+ <td align="right">79.0%</td>
132
+ <td align="right">81.1%</td>
133
+ <td align="right">70.8%</td>
134
+ <td align="right">0.417</td>
135
+ <td align="right">0.431</td>
136
+ <td align="right">47.2</td>
137
+ <td align="right">69.5%</td>
138
+ <td align="right">43.2</td>
139
+ </tr>
140
+ <tr>
141
+ <td>ViT-B/14</td>
142
+ <td align="right">82.1%</td>
143
+ <td align="right">84.5%</td>
144
+ <td align="right">74.9%</td>
145
+ <td align="right">0.362</td>
146
+ <td align="right">0.400</td>
147
+ <td align="right">51.3</td>
148
+ <td align="right">76.3%</td>
149
+ <td align="right">49.5</td>
150
+ </tr>
151
+ <tr>
152
+ <td>ViT-L/14</td>
153
+ <td align="right">83.5%</td>
154
+ <td align="right">86.3%</td>
155
+ <td align="right">77.6%</td>
156
+ <td align="right">0.333</td>
157
+ <td align="right">0.396</td>
158
+ <td align="right">53.1</td>
159
+ <td align="right">79.8%</td>
160
+ <td align="right">54.0</td>
161
+ </tr>
162
+ <tr>
163
+ <td>ViT-g/14</td>
164
+ <td align="right">83.5%</td>
165
+ <td align="right">86.5%</td>
166
+ <td align="right">78.4%</td>
167
+ <td align="right">0.298</td>
168
+ <td align="right">0.362</td>
169
+ <td align="right">53.0</td>
170
+ <td align="right">81.6%</td>
171
+ <td align="right">52.3</td>
172
+ </tr>
173
+ </table>
174
+
175
+ ## Environmental Impact
176
+
177
+ - **Hardware Type:** Nvidia A100
178
+ - **Hours used:** 22,000 for ViT-g, 4,500 for ViT-S distillation, 5,300 for ViT-B distillation, 8,000 for ViT-L distillation
179
+ - **Cloud Provider:** Private infra
180
+ - **Compute Region:** USA
181
+ - **Carbon Emitted:** 7t CO2eq
182
+
183
+ #### Hardware
184
+
185
+ Nvidia A100 GPUs
186
+
187
+ #### Software
188
+
189
+ PyTorch 2.0,
190
+ xFormers 0.0.18
191
+
192
+ **BibTeX**
193
+
194
+ ```
195
+ @misc{oquab2023dinov2,
196
+ title={DINOv2: Learning Robust Visual Features without Supervision},
197
+ author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
198
+ journal={arXiv:2304.07193},
199
+ year={2023}
200
+ }
201
+ ```
README.md CHANGED
@@ -1,12 +1,501 @@
  ---
- title: DinoV2 Semantic Segmentation
- emoji: 📊
- colorFrom: gray
- colorTo: pink
- sdk: gradio
- sdk_version: 3.44.1
+ title: DinoV2_Semantic_Segmentation
  app_file: app.py
+ sdk: gradio
+ sdk_version: 3.42.0
- pinned: false
  ---
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DINOv2: Learning Robust Visual Features without Supervision
8
+
9
+ **[Meta AI Research, FAIR](https://ai.facebook.com/research/)**
10
+
11
+ Maxime Oquab,
12
+ Timothée Darcet,
13
+ Théo Moutakanni,
14
+ Huy V. Vo,
15
+ Marc Szafraniec,
16
+ Vasil Khalidov,
17
+ Patrick Labatut,
18
+ Armand Joulin,
19
+ Piotr Bojanowski
20
+
21
+ [[`Paper`](https://arxiv.org/abs/2304.07193)] [[`Blog`](https://ai.facebook.com/blog/dino-v2-computer-vision-self-supervised-learning/)] [[`Demo`](https://dinov2.metademolab.com)] [[`BibTeX`](#citing-dinov2)]
22
+
23
+ PyTorch implementation and pretrained models for DINOv2. For details, see the paper: **[DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)**.
24
+
25
+ DINOv2 models produce high-performance visual features that can be directly employed with classifiers as simple as linear layers on a variety of computer vision tasks; these visual features are robust and perform well across domains without any requirement for fine-tuning. The models were pretrained on a dataset of 142 M images without using any labels or annotations.
26
+
27
+ https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b429-578badf5c356
28
+
29
+ <div align="center">
30
+ Visualization of the three first principal components of the patch features of all frames, mapped to RGB values.
31
+ </div>
32
+
33
+ ## Pretrained models
34
+
35
+ <table style="margin: auto">
36
+ <thead>
37
+ <tr>
38
+ <th>model</th>
39
+ <th># of<br />params</th>
40
+ <th>ImageNet<br />k-NN</th>
41
+ <th>ImageNet<br />linear</th>
42
+ <th>download</th>
43
+ </tr>
44
+ </thead>
45
+ <tbody>
46
+ <tr>
47
+ <td>ViT-S/14 distilled</td>
48
+ <td align="right">21 M</td>
49
+ <td align="right">79.0%</td>
50
+ <td align="right">81.1%</td>
51
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth">backbone only</a></td>
52
+ </tr>
53
+ <tr>
54
+ <td>ViT-B/14 distilled</td>
55
+ <td align="right">86 M</td>
56
+ <td align="right">82.1%</td>
57
+ <td align="right">84.5%</td>
58
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_pretrain.pth">backbone only</a></td>
59
+ </tr>
60
+ <tr>
61
+ <td>ViT-L/14 distilled</td>
62
+ <td align="right">300 M</td>
63
+ <td align="right">83.5%</td>
64
+ <td align="right">86.3%</td>
65
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth">backbone only</a></td>
66
+ </tr>
67
+ <tr>
68
+ <td>ViT-g/14</td>
69
+ <td align="right">1,100 M</td>
70
+ <td align="right">83.5%</td>
71
+ <td align="right">86.5%</td>
72
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth">backbone only</a></td>
73
+ </tr>
74
+ </tbody>
75
+ </table>
76
+
77
+ ### Pretrained backbones (via PyTorch Hub)
78
+
79
+ Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended.
80
+
81
+ A corresponding [model card](MODEL_CARD.md) is included in the repository.
82
+
83
+ ```python
84
+ import torch
85
+
86
+ dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')
87
+ dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
88
+ dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
89
+ dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')
90
+ ```
91
+
92
+ ### Pretrained heads - Image classification
93
+
94
+ <table style="margin: auto">
95
+ <thead>
96
+ <tr>
97
+ <th rowspan="2">backbone</th>
98
+ <th>download</th>
99
+ </tr>
100
+ <tr>
101
+ <th>ImageNet</th>
102
+ </tr>
103
+ </thead>
104
+ <tbody>
105
+ <tr>
106
+ <td>ViT-S/14 distilled</td>
107
+ <td>
108
+ linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">1 layer</a>,
109
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear4_head.pth">4 layers</a>)
110
+ </td>
111
+ </tr>
112
+ <tr>
113
+ <td>ViT-B/14 distilled</td>
114
+ <td>
115
+ linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>,
116
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear4_head.pth">4 layers</a>)
117
+ </tr>
118
+ <tr>
119
+ <td>ViT-L/14 distilled</td>
120
+ <td>
121
+ linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>,
122
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear4_head.pth">4 layers</a>)
123
+ </tr>
124
+ <tr>
125
+ <td>ViT-g/14</td>
126
+ <td>
127
+ linear head (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>,
128
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear4_head.pth">4 layers</a>)
129
+ </tr>
130
+ </tbody>
131
+ </table>
132
+
133
+ The (full) classifier models can be loaded via PyTorch Hub:
134
+
135
+ ```python
136
+ import torch
137
+
138
+ dinov2_vits14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14_lc')
139
+ dinov2_vitb14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_lc')
140
+ dinov2_vitl14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14_lc')
141
+ dinov2_vitg14_lc = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14_lc')
142
+ ```
143
+
144
+ ### Pretrained heads - Depth estimation
145
+
146
+ <table style="margin: auto">
147
+ <thead>
148
+ <tr>
149
+ <th rowspan="2">backbone</th>
150
+ <th colspan="2">download head</th>
151
+ </tr>
152
+ <tr>
153
+ <th>NYUd</th>
154
+ <th>KITTI</th>
155
+ </tr>
156
+ </thead>
157
+ <tbody>
158
+ <tr>
159
+ <td>ViT-S/14 distilled</td>
160
+ <td>
161
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_linear_head.pth">1 layer</a>,
162
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_linear4_head.pth">4 layers</a>),
163
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_nyu_dpt_head.pth">DPT</a>
164
+ </td>
165
+ <td>
166
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_linear_head.pth">1 layer</a>,
167
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_linear4_head.pth">4 layers</a>),
168
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_kitti_dpt_head.pth">DPT</a>
169
+ </td>
170
+ </tr>
171
+ <tr>
172
+ <td>ViT-B/14 distilled</td>
173
+ <td>
174
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">1 layer</a>,
175
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_linear4_head.pth">4 layers</a>),
176
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_nyu_dpt_head.pth">DPT</a>
177
+ </td>
178
+ <td>
179
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_linear_head.pth">1 layer</a>,
180
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_linear4_head.pth">4 layers</a>),
181
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_kitti_dpt_head.pth">DPT</a>
182
+ </td>
183
+ </tr>
184
+ <tr>
185
+ <td>ViT-L/14 distilled</td>
186
+ <td>
187
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">1 layer</a>,
188
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_linear4_head.pth">4 layers</a>),
189
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_nyu_dpt_head.pth">DPT</a>
190
+ </td>
191
+ <td>
192
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_linear_head.pth">1 layer</a>,
193
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_linear4_head.pth">4 layers</a>),
194
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_kitti_dpt_head.pth">DPT</a>
195
+ </td>
196
+ </tr>
197
+ <tr>
198
+ <td>ViT-g/14</td>
199
+ <td>
200
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">1 layer</a>,
201
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_linear4_head.pth">4 layers</a>),
202
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_nyu_dpt_head.pth">DPT</a>
203
+ </td>
204
+ <td>
205
+ linear (<a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_linear_head.pth">1 layer</a>,
206
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_linear4_head.pth">4 layers</a>),
207
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_kitti_dpt_head.pth">DPT</a>
208
+ </td>
209
+ </tr>
210
+ </tbody>
211
+ </table>
212
+
213
+ ### Pretrained heads - Semantic segmentation
214
+
215
+ <table style="margin: auto">
216
+ <thead>
217
+ <tr>
218
+ <th rowspan="2">backbone</th>
219
+ <th>download model</th>
220
+ <th colspan="2">download head</th>
221
+ </tr>
222
+ <tr>
223
+ <th>ADE20K</th>
224
+ <th>ADE20K</th>
225
+ <th>VOC2012</th>
226
+ </tr>
227
+ </thead>
228
+ <tbody>
229
+ <tr>
230
+ <td>ViT-S/14 distilled</td>
231
+ <td></td>
232
+ <td>
233
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_linear_head.pth">linear</a>,
234
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_ade20k_ms_head.pth">multi-scale</a>
235
+ </td>
236
+ <td>
237
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_voc2012_linear_head.pth">linear</a>,
238
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_voc2012_ms_head.pth">multi-scale</a>
239
+ </td>
240
+ </tr>
241
+ <tr>
242
+ <td>ViT-B/14 distilled</td>
243
+ <td></td>
244
+ <td>
245
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_ade20k_linear_head.pth">linear</a>,
246
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_ade20k_ms_head.pth">multi-scale</a>
247
+ </td>
248
+ <td>
249
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_voc2012_linear_head.pth">linear</a>,
250
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_voc2012_ms_head.pth">multi-scale</a>
251
+ </td>
252
+ </tr>
253
+ <tr>
254
+ <td>ViT-L/14 distilled</td>
255
+ <td></td>
256
+ <td>
257
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_ade20k_linear_head.pth">linear</a>,
258
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_ade20k_ms_head.pth">multi-scale</a>
259
+ </td>
260
+ <td>
261
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_voc2012_linear_head.pth">linear</a>,
262
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_voc2012_ms_head.pth">multi-scale</a>
263
+ </td>
264
+ </tr>
265
+ <tr>
266
+ <td>ViT-g/14</td>
267
+ <td>
268
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth">Mask2Former</a>
269
+ </td>
270
+ <td>
271
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_linear_head.pth">linear</a>,
272
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_ade20k_ms_head.pth">multi-scale</a>
273
+ </td>
274
+ <td>
275
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_voc2012_linear_head.pth">linear</a>,
276
+ <a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_voc2012_ms_head.pth">multi-scale</a>
277
+ </td>
278
+ </tr>
279
+ </tbody>
280
+ </table>
281
+
282
+ ## Installation
283
+
284
+ The training and evaluation code requires PyTorch 2.0 and [xFormers](https://github.com/facebookresearch/xformers) 0.0.18 as well as a number of other 3rd party packages. Note that the code has only been tested with the specified versions and also expects a Linux environment. To setup all the required dependencies for training and evaluation, please follow the instructions below:
285
+
286
+ *[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov2` conda environment using the provided environment definition:
287
+
288
+ ```shell
289
+ conda env create -f conda.yaml
290
+ conda activate dinov2
291
+ ```
292
+
293
+ *[pip](https://pip.pypa.io/en/stable/getting-started/)* - Clone the repository and then use the provided `requirements.txt` to install the dependencies:
294
+
295
+ ```shell
296
+ pip install -r requirements.txt
297
+ ```
298
+
299
+ For dense tasks (depth estimation and semantic segmentation), there are additional dependencies (specific versions of `mmcv` and `mmsegmentation`) which are captured in the `extras` dependency specifications:
300
+
301
+ *[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)**:
302
+
303
+ ```shell
304
+ conda env create -f conda-extras.yaml
305
+ conda activate dinov2-extras
306
+ ```
307
+
308
+ *[pip](https://pip.pypa.io/en/stable/getting-started/)*:
309
+
310
+ ```shell
311
+ pip install -r requirements.txt -r requirements-extras.txt
312
+ ```
313
+
314
+ ## Data preparation
315
+
316
+ ### ImageNet-1k
317
+
318
+ The root directory of the dataset should hold the following contents:
319
+
320
+ - `<ROOT>/test/ILSVRC2012_test_00000001.JPEG`
321
+ - `<ROOT>/test/[..]`
322
+ - `<ROOT>/test/ILSVRC2012_test_00100000.JPEG`
323
+ - `<ROOT>/train/n01440764/n01440764_10026.JPEG`
324
+ - `<ROOT>/train/[...]`
325
+ - `<ROOT>/train/n15075141/n15075141_9993.JPEG`
326
+ - `<ROOT>/val/n01440764/ILSVRC2012_val_00000293.JPEG`
327
+ - `<ROOT>/val/[...]`
328
+ - `<ROOT>/val/n15075141/ILSVRC2012_val_00049174.JPEG`
329
+ - `<ROOT>/labels.txt`
330
+
331
+ The provided dataset implementation expects a few additional metadata files to be present under the extra directory:
332
+
333
+ - `<EXTRA>/class-ids-TRAIN.npy`
334
+ - `<EXTRA>/class-ids-VAL.npy`
335
+ - `<EXTRA>/class-names-TRAIN.npy`
336
+ - `<EXTRA>/class-names-VAL.npy`
337
+ - `<EXTRA>/entries-TEST.npy`
338
+ - `<EXTRA>/entries-TRAIN.npy`
339
+ - `<EXTRA>/entries-VAL.npy`
340
+
341
+ These metadata files can be generated (once) with the following lines of Python code:
342
+
343
+ ```python
344
+ from dinov2.data.datasets import ImageNet
345
+
346
+ for split in ImageNet.Split:
347
+ dataset = ImageNet(split=split, root="<ROOT>", extra="<EXTRA>")
348
+ dataset.dump_extra()
349
+ ```
350
+
351
+ Note that the root and extra directories do not have to be distinct directories.
352
+
353
+ ### ImageNet-22k
354
+
355
+ Please adapt the [dataset class](dinov2/data/datasets/image_net_22k.py) to match your local setup.
356
+
357
+ <br />
358
+
359
+ :warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`.
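For example, an evaluation entry point can be launched from the repository root like this (illustrative invocation; `--help` is only used here to confirm the module path resolves):

```shell
PYTHONPATH=. python dinov2/run/eval/knn.py --help
```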
360
+
361
+ ## Training
362
+
363
+ ### Fast setup: training DINOv2 ViT-L/16 on ImageNet-1k
364
+
365
+ Run DINOv2 training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit:
366
+
367
+ ```shell
368
+ python dinov2/run/train/train.py \
369
+ --nodes 4 \
370
+ --config-file dinov2/configs/train/vitl16_short.yaml \
371
+ --output-dir <PATH/TO/OUTPUT/DIR> \
372
+ train.dataset_path=ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
373
+ ```
374
+
375
+ Training time is approximately 1 day and the resulting checkpoint should reach 81.6% on k-NN eval and 82.9% on linear eval.
376
+
377
+ The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
378
+
379
+ ### Long setup: training DINOv2 ViT-L/14 on ImageNet-22k
380
+
381
+ Run DINOv2 training on 12 A100-80GB nodes (96 GPUs) in a SLURM cluster environment with submitit:
382
+
383
+ ```shell
384
+ python dinov2/run/train/train.py \
385
+ --nodes 12 \
386
+ --config-file dinov2/configs/train/vitl14.yaml \
387
+ --output-dir <PATH/TO/OUTPUT/DIR> \
388
+ train.dataset_path=ImageNet22k:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
389
+ ```
390
+
391
+ Training time is approximately 3.3 days and the resulting checkpoint should reach 82.0% on k-NN eval and 84.5% on linear eval.
392
+
393
+ The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation.
394
+
395
+
396
+ ## Evaluation
397
+
398
+ The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node:
399
+
400
+ ### k-NN classification on ImageNet-1k
401
+
402
+ ```shell
403
+ python dinov2/run/eval/knn.py \
404
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
405
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
406
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/knn \
407
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
408
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
409
+ ```
410
+
411
+ ### Logistic regression classification on ImageNet-1k
412
+
413
+ ```shell
414
+ python dinov2/run/eval/log_regression.py \
415
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
416
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
417
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/logreg \
418
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
419
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
420
+ ```
421
+
422
+ ### Linear classification with data augmentation on ImageNet-1k
423
+
424
+ ```shell
425
+ python dinov2/run/eval/linear.py \
426
+ --config-file <PATH/TO/OUTPUT/DIR>/config.yaml \
427
+ --pretrained-weights <PATH/TO/OUTPUT/DIR>/eval/training_24999/teacher_checkpoint.pth \
428
+ --output-dir <PATH/TO/OUTPUT/DIR>/eval/training_24999/linear \
429
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
430
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
431
+ ```
432
+
433
+ We release the weights from evaluating the different models:
434
+
435
+ <table style="margin: auto">
436
+ <tr>
437
+ <th>model</th>
438
+ <th>ImageNet<br />top-1</th>
439
+ <th>linear evaluation</th>
440
+ </tr>
441
+ <tr>
442
+ <td>ViT-S/14 distilled</td>
443
+ <td align="right">81.1%</td>
444
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_linear_head.pth">linear head weights</a></td>
445
+ </tr>
446
+ <tr>
447
+ <td>ViT-B/14 distilled</td>
448
+ <td align="right">84.5%</td>
449
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_linear_head.pth">linear head weights</a></td>
450
+ </tr>
451
+ <tr>
452
+ <td>ViT-L/14 distilled</td>
453
+ <td align="right">86.3%</td>
454
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_linear_head.pth">linear head weights</a></td>
455
+ </tr>
456
+ <tr>
457
+ <td>ViT-g/14</td>
458
+ <td align="right">86.5%</td>
459
+ <td><a href="https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_linear_head.pth">linear head weights</a></td>
460
+ </tr>
461
+ </table>
462
+
463
+ The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k:
464
+
465
+ ```shell
466
+ python dinov2/run/eval/linear.py \
467
+ --config-file dinov2/configs/eval/vitg14_pretrain.yaml \
468
+ --pretrained-weights https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth \
469
+ --train-dataset ImageNet:split=TRAIN:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET> \
470
+ --val-dataset ImageNet:split=VAL:root=<PATH/TO/DATASET>:extra=<PATH/TO/DATASET>
471
+ ```
472
+
473
+ ## Notebooks
474
+
475
+ A few notebooks are provided to help the community leverage the models and code:
476
+
477
+ <ul>
478
+ <li><a href="https://github.com/facebookresearch/dinov2/blob/main/notebooks/depth_estimation.ipynb">Depth estimation</a> - How to load and use the depth heads in combination with a matching backbone via mmcv</li>
479
+ <li><a href="https://github.com/facebookresearch/dinov2/blob/main/notebooks/semantic_segmentation.ipynb">Semantic segmentation</a> - How to load and use the segmentation heads in combination with a matching backbone via mmcv, and also how to load and use the Mask2Former-based segmentation model trained on ADE20K</li>
480
+ </ul>
481
+
482
+ ## License
483
+
484
+ DINOv2 code and model weights are released under the Apache License 2.0. See [LICENSE](LICENSE) for additional details.
485
+
486
+ ## Contributing
487
+
488
+ See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md).
489
+
490
+ ## Citing DINOv2
491
+
492
+ If you find this repository useful, please consider giving a star :star: and citation :t-rex::
493
 
494
+ ```
495
+ @misc{oquab2023dinov2,
496
+ title={DINOv2: Learning Robust Visual Features without Supervision},
497
+ author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr},
498
+ journal={arXiv:2304.07193},
499
+ year={2023}
500
+ }
501
+ ```
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ os.system("pip uninstall -y mmcv-full")
4
+ os.system("pip uninstall -y mmsegmentation")
5
+ os.system("pip install ./mmcv_full-1.5.0-cp310-cp310-linux_x86_64.whl")
6
+ os.system("pip install -r requirements-extras.txt")
7
+ # os.system("cp /home/user/data/dinov2_vitg14_ade20k_m2f.pth /home/user/.cache/torch/hub/checkpoints/dinov2_vitg14_ade20k_m2f.pth")
8
+
9
+ import gradio as gr
10
+
11
+ import base64
12
+ import cv2
13
+ import math
14
+ import itertools
15
+ from functools import partial
16
+ from PIL import Image
17
+ import numpy as np
18
+ import pandas as pd
19
+
20
+ import dinov2.eval.segmentation.utils.colormaps as colormaps
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from mmseg.apis import init_segmentor, inference_segmentor
25
+
26
+ import dinov2.eval.segmentation.models
27
+ import dinov2.eval.segmentation_m2f.models.segmentors
28
+
29
+ import urllib
30
+
31
+ import mmcv
32
+ from mmcv.runner import load_checkpoint
33
+
34
+ model = None
35
+ model_loaded = False
36
+
37
+ DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
38
+ CONFIG_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f_config.py"
39
+ CHECKPOINT_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth"
40
+
41
+
42
+ def load_config_from_url(url: str) -> str:
43
+ with urllib.request.urlopen(url) as f:
44
+ return f.read().decode()
45
+
46
+
47
+ cfg_str = load_config_from_url(CONFIG_URL)
48
+ cfg = mmcv.Config.fromstring(cfg_str, file_format=".py")
49
+
50
+
51
+ DATASET_COLORMAPS = {
52
+ "ade20k": colormaps.ADE20K_COLORMAP,
53
+ "voc2012": colormaps.VOC2012_COLORMAP,
54
+ }
55
+
56
+ model = init_segmentor(cfg)
57
+ load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
58
+ model.cuda()
59
+ model.eval()
60
+
61
+ class CenterPadding(torch.nn.Module):
62
+ def __init__(self, multiple):
63
+ super().__init__()
64
+ self.multiple = multiple
65
+
66
+ def _get_pad(self, size):
67
+ new_size = math.ceil(size / self.multiple) * self.multiple
68
+ pad_size = new_size - size
69
+ pad_size_left = pad_size // 2
70
+ pad_size_right = pad_size - pad_size_left
71
+ return pad_size_left, pad_size_right
72
+
73
+ @torch.inference_mode()
74
+ def forward(self, x):
75
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
76
+ output = F.pad(x, pads)
77
+ return output
78
+
79
+
80
+ def create_segmenter(cfg, backbone_model):
81
+ model = init_segmentor(cfg)
82
+ model.backbone.forward = partial(
83
+ backbone_model.get_intermediate_layers,
84
+ n=cfg.model.backbone.out_indices,
85
+ reshape=True,
86
+ )
87
+ if hasattr(backbone_model, "patch_size"):
88
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone_model.patch_size)(x[0]))
89
+ model.init_weights()
90
+ return model
91
+
92
+
93
+ def render_segmentation(segmentation_logits, dataset):
94
+ colormap = DATASET_COLORMAPS[dataset]
95
+ colormap_array = np.array(colormap, dtype=np.uint8)
96
+ segmentation_logits += 1
97
+ segmentation_values = colormap_array[segmentation_logits]
98
+ unique_labels = np.unique(segmentation_logits)
99
+
100
+ colormap_array = colormap_array[unique_labels]
101
+ df = pd.read_csv("labelmap.txt", sep="\t")
102
+
103
+ html_output = '<div style="display: flex; flex-wrap: wrap;">'
104
+ import matplotlib.pyplot as plt
105
+
106
+ for idx, color in enumerate(colormap_array):
107
+ color_box = np.zeros((20, 20, 3), dtype=np.uint8)
108
+ color_box[:, :] = color
109
+ color_box = cv2.cvtColor(color_box, cv2.COLOR_RGB2BGR)
110
+ _, img_data = cv2.imencode(".jpg", color_box)
111
+ img_base64 = base64.b64encode(img_data).decode("utf-8")
112
+ img_data_uri = f"data:image/jpg;base64,{img_base64}"
113
+ html_output += f'<div style="margin: 10px;"><img src="{img_data_uri}" /><p>{df.iloc[unique_labels[idx]-1]["Name"]}</p></div>'
114
+
115
+ html_output += "</div>"
116
+
117
+ return Image.fromarray(segmentation_values), html_output
118
+
119
+
120
+ def predict(image_file):
121
+ array = np.array(image_file)[:, :, ::-1] # BGR
122
+ segmentation_logits = inference_segmentor(model, array)[0]
123
+ segmented_image, html_output = render_segmentation(segmentation_logits, "ade20k")
124
+ return np.array(segmented_image), html_output
125
+
126
+ description = "Gradio demo for semantic segmentation with DINOv2. To use it, simply upload an image or pick one of the examples."
127
+
128
+ demo = gr.Interface(
129
+ title="Semantic Segmentation - DinoV2",
130
+ fn=predict,
131
+ inputs=gr.inputs.Image(),
132
+ outputs=[gr.outputs.Image(type="numpy"), gr.outputs.HTML()],
133
+ examples=["example_1.jpg", "example_2.jpg"],
134
+ description=description,
135
+ )
136
+
137
+ demo.launch()
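
The `predict` function above can also be exercised without the Gradio UI, which is handy for a quick smoke test. A minimal sketch, assuming the module-level `model` has already been built by the lines above (for example in a notebook, running everything except `demo.launch()`) and that `example_1.jpg` from the examples list is present in the working directory:

```python
from PIL import Image

# Hypothetical smoke test for predict(); run inside the app.py module context.
image = Image.open("example_1.jpg").convert("RGB")
segmented_array, legend_html = predict(image)

Image.fromarray(segmented_array).save("segmented.jpg")
with open("legend.html", "w") as f:
    f.write(legend_html)
```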
conda-extras.yaml ADDED
@@ -0,0 +1,24 @@
+ name: dinov2-extras
+ channels:
+   - defaults
+   - pytorch
+   - nvidia
+   - xformers
+   - conda-forge
+ dependencies:
+   - python=3.9
+   - pytorch::pytorch=2.0.0
+   - pytorch::pytorch-cuda=11.7.0
+   - pytorch::torchvision=0.15.0
+   - omegaconf
+   - torchmetrics=0.10.3
+   - fvcore
+   - iopath
+   - xformers::xformers=0.0.18
+   - pip
+   - pip:
+     - git+https://github.com/facebookincubator/submitit
+     - --extra-index-url https://pypi.nvidia.com
+     - cuml-cu11
+     - mmcv-full==1.5.0
+     - mmsegmentation==0.27.0
conda.yaml ADDED
@@ -0,0 +1,22 @@
+ name: dinov2
+ channels:
+   - defaults
+   - pytorch
+   - nvidia
+   - xformers
+   - conda-forge
+ dependencies:
+   - python=3.9
+   - pytorch::pytorch=2.0.0
+   - pytorch::pytorch-cuda=11.7.0
+   - pytorch::torchvision=0.15.0
+   - omegaconf
+   - torchmetrics=0.10.3
+   - fvcore
+   - iopath
+   - xformers::xformers=0.0.18
+   - pip
+   - pip:
+     - git+https://github.com/facebookincubator/submitit
+     - --extra-index-url https://pypi.nvidia.com
+     - cuml-cu11
demo.py ADDED
@@ -0,0 +1,153 @@
1
+ import sys
2
+ REPO_PATH="."
3
+ sys.path.append("/dino_v2")
4
+
5
+ import math
6
+ import itertools
7
+ from functools import partial
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from mmseg.apis import init_segmentor, inference_segmentor
12
+
13
+ import dinov2.eval.segmentation.models
14
+
15
+
16
+ class CenterPadding(torch.nn.Module):
17
+ def __init__(self, multiple):
18
+ super().__init__()
19
+ self.multiple = multiple
20
+
21
+ def _get_pad(self, size):
22
+ new_size = math.ceil(size / self.multiple) * self.multiple
23
+ pad_size = new_size - size
24
+ pad_size_left = pad_size // 2
25
+ pad_size_right = pad_size - pad_size_left
26
+ return pad_size_left, pad_size_right
27
+
28
+ @torch.inference_mode()
29
+ def forward(self, x):
30
+ pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
31
+ output = F.pad(x, pads)
32
+ return output
33
+
34
+
35
+ def create_segmenter(cfg, backbone_model):
36
+ model = init_segmentor(cfg)
37
+ model.backbone.forward = partial(
38
+ backbone_model.get_intermediate_layers,
39
+ n=cfg.model.backbone.out_indices,
40
+ reshape=True,
41
+ )
42
+ if hasattr(backbone_model, "patch_size"):
43
+ model.backbone.register_forward_pre_hook(lambda _, x: CenterPadding(backbone_model.patch_size)(x[0]))
44
+ model.init_weights()
45
+ return model
46
+
47
+ # BACKBONE_SIZE = "small" # in ("small", "base", "large" or "giant")
48
+
49
+
50
+ # backbone_archs = {
51
+ # "small": "vits14",
52
+ # "base": "vitb14",
53
+ # "large": "vitl14",
54
+ # "giant": "vitg14",
55
+ # }
56
+ # backbone_arch = backbone_archs[BACKBONE_SIZE]
57
+ # backbone_name = f"dinov2_{backbone_arch}"
58
+
59
+ # backbone_model = torch.hub.load(repo_or_dir="facebookresearch/dinov2", model=backbone_name)
60
+ # backbone_model.eval()
61
+ # backbone_model.cuda()
62
+
63
+ import urllib
64
+
65
+ import mmcv
66
+ from mmcv.runner import load_checkpoint
67
+
68
+
69
+ def load_config_from_url(url: str) -> str:
70
+ with urllib.request.urlopen(url) as f:
71
+ return f.read().decode()
72
+
73
+
74
+ # HEAD_SCALE_COUNT = 3 # more scales: slower but better results, in (1,2,3,4,5)
75
+ # HEAD_DATASET = "voc2012" # in ("ade20k", "voc2012")
76
+ # HEAD_TYPE = "ms" # in ("ms, "linear")
77
+
78
+
79
+ DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
80
+ # head_config_url = f"{DINOV2_BASE_URL}/{backbone_name}/{backbone_name}_{HEAD_DATASET}_{HEAD_TYPE}_config.py"
81
+ # head_checkpoint_url = f"{DINOV2_BASE_URL}/{backbone_name}/{backbone_name}_{HEAD_DATASET}_{HEAD_TYPE}_head.pth"
82
+
83
+ # cfg_str = load_config_from_url(head_config_url)
84
+ # cfg = mmcv.Config.fromstring(cfg_str, file_format=".py")
85
+ # if HEAD_TYPE == "ms":
86
+ # cfg.data.test.pipeline[1]["img_ratios"] = cfg.data.test.pipeline[1]["img_ratios"][:HEAD_SCALE_COUNT]
87
+ # print("scales:", cfg.data.test.pipeline[1]["img_ratios"])
88
+
89
+ # model = create_segmenter(cfg, backbone_model=backbone_model)
90
+ # load_checkpoint(model, head_checkpoint_url, map_location="cpu")
91
+ # model.cuda()
92
+ # model.eval()
93
+
94
+ import urllib
95
+
96
+ from PIL import Image
97
+
98
+
99
+ def load_image_from_url(url: str) -> Image:
100
+ with urllib.request.urlopen(url) as f:
101
+ return Image.open(f).convert("RGB")
102
+
103
+
104
+ EXAMPLE_IMAGE_URL = "https://dl.fbaipublicfiles.com/dinov2/images/example.jpg"
105
+
106
+
107
+ # image = load_image_from_url(EXAMPLE_IMAGE_URL)
108
+ image = Image.open("bridge_2.JPG").convert("RGB")
109
+
110
+ image.show()
111
+
112
+ import numpy as np
113
+
114
+ import dinov2.eval.segmentation.utils.colormaps as colormaps
115
+
116
+
117
+ DATASET_COLORMAPS = {
118
+ "ade20k": colormaps.ADE20K_COLORMAP,
119
+ "voc2012": colormaps.VOC2012_COLORMAP,
120
+ }
121
+
122
+
123
+ def render_segmentation(segmentation_logits, dataset):
124
+ colormap = DATASET_COLORMAPS[dataset]
125
+ colormap_array = np.array(colormap, dtype=np.uint8)
126
+ print(len(colormap))
127
+ segmentation_values = colormap_array[segmentation_logits + 1]
128
+ return Image.fromarray(segmentation_values)
129
+
130
+
131
+ # array = np.array(image)[:, :, ::-1] # BGR
132
+ # segmentation_logits = inference_segmentor(model, array)[0]
133
+ # segmented_image = render_segmentation(segmentation_logits, HEAD_DATASET)
134
+ # segmented_image.save("output.jpg")
135
+
136
+ import dinov2.eval.segmentation_m2f.models.segmentors
137
+
138
+ CONFIG_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f_config.py"
139
+ CHECKPOINT_URL = f"{DINOV2_BASE_URL}/dinov2_vitg14/dinov2_vitg14_ade20k_m2f.pth"
140
+
141
+ cfg_str = load_config_from_url(CONFIG_URL)
142
+ cfg = mmcv.Config.fromstring(cfg_str, file_format=".py")
143
+
144
+ model = init_segmentor(cfg)
145
+ load_checkpoint(model, CHECKPOINT_URL, map_location="cpu")
146
+ model.cuda()
147
+ model.eval()
148
+
149
+ array = np.array(image)[:, :, ::-1] # BGR
150
+ segmentation_logits = inference_segmentor(model, array)[0]
151
+ print(np.unique(segmentation_logits, return_counts=True))
152
+ segmented_image = render_segmentation(segmentation_logits, "ade20k")
153
+ segmented_image.save("output.jpg")
dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ __version__ = "0.0.1"
dinov2/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (187 Bytes)
dinov2/configs/__init__.py ADDED
@@ -0,0 +1,22 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import pathlib
7
+
8
+ from omegaconf import OmegaConf
9
+
10
+
11
+ def load_config(config_name: str):
12
+ config_filename = config_name + ".yaml"
13
+ return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename)
14
+
15
+
16
+ dinov2_default_config = load_config("ssl_default_config")
17
+
18
+
19
+ def load_and_merge_config(config_name: str):
20
+ default_config = OmegaConf.create(dinov2_default_config)
21
+ loaded_config = load_config(config_name)
22
+ return OmegaConf.merge(default_config, loaded_config)
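
`load_and_merge_config` overlays one of the YAML files in this directory onto `ssl_default_config.yaml`. A short usage sketch with one of the evaluation configs added below:

```python
from dinov2.configs import load_and_merge_config

# Merge the ViT-L/14 eval config over the SSL defaults.
cfg = load_and_merge_config("eval/vitl14_pretrain")
print(cfg.student.arch)             # vit_large
print(cfg.student.patch_size)       # 14 (overrides the default of 16)
print(cfg.crops.global_crops_size)  # 518 (overrides the default of 224)
```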
dinov2/configs/eval/vitb14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+ student:
+   arch: vit_base
+   patch_size: 14
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vitg14_pretrain.yaml ADDED
@@ -0,0 +1,7 @@
+ student:
+   arch: vit_giant2
+   patch_size: 14
+   ffn_layer: swiglufused
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vitl14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+ student:
+   arch: vit_large
+   patch_size: 14
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
dinov2/configs/eval/vits14_pretrain.yaml ADDED
@@ -0,0 +1,6 @@
+ student:
+   arch: vit_small
+   patch_size: 14
+ crops:
+   global_crops_size: 518 # this is to set up the position embeddings properly
+   local_crops_size: 98
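
The `global_crops_size: 518` shared by these eval configs is a multiple of the 14-pixel patch size, which is what makes the position embeddings line up; the arithmetic, as a quick check:

```python
patch_size = 14
global_crops_size = 518

assert global_crops_size % patch_size == 0
patches_per_side = global_crops_size // patch_size
print(patches_per_side)       # 37
print(patches_per_side ** 2)  # 1369 patch tokens per global crop
```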
dinov2/configs/ssl_default_config.yaml ADDED
@@ -0,0 +1,115 @@
1
+ MODEL:
2
+ WEIGHTS: ''
3
+ compute_precision:
4
+ grad_scaler: true
5
+ teacher:
6
+ backbone:
7
+ sharding_strategy: SHARD_GRAD_OP
8
+ mixed_precision:
9
+ param_dtype: fp16
10
+ reduce_dtype: fp16
11
+ buffer_dtype: fp32
12
+ dino_head:
13
+ sharding_strategy: SHARD_GRAD_OP
14
+ mixed_precision:
15
+ param_dtype: fp16
16
+ reduce_dtype: fp16
17
+ buffer_dtype: fp32
18
+ ibot_head:
19
+ sharding_strategy: SHARD_GRAD_OP
20
+ mixed_precision:
21
+ param_dtype: fp16
22
+ reduce_dtype: fp16
23
+ buffer_dtype: fp32
24
+ student:
25
+ backbone:
26
+ sharding_strategy: SHARD_GRAD_OP
27
+ mixed_precision:
28
+ param_dtype: fp16
29
+ reduce_dtype: fp16
30
+ buffer_dtype: fp32
31
+ dino_head:
32
+ sharding_strategy: SHARD_GRAD_OP
33
+ mixed_precision:
34
+ param_dtype: fp16
35
+ reduce_dtype: fp32
36
+ buffer_dtype: fp32
37
+ ibot_head:
38
+ sharding_strategy: SHARD_GRAD_OP
39
+ mixed_precision:
40
+ param_dtype: fp16
41
+ reduce_dtype: fp32
42
+ buffer_dtype: fp32
43
+ dino:
44
+ loss_weight: 1.0
45
+ head_n_prototypes: 65536
46
+ head_bottleneck_dim: 256
47
+ head_nlayers: 3
48
+ head_hidden_dim: 2048
49
+ koleo_loss_weight: 0.1
50
+ ibot:
51
+ loss_weight: 1.0
52
+ mask_sample_probability: 0.5
53
+ mask_ratio_min_max:
54
+ - 0.1
55
+ - 0.5
56
+ separate_head: false
57
+ head_n_prototypes: 65536
58
+ head_bottleneck_dim: 256
59
+ head_nlayers: 3
60
+ head_hidden_dim: 2048
61
+ train:
62
+ batch_size_per_gpu: 64
63
+ dataset_path: ImageNet:split=TRAIN
64
+ output_dir: .
65
+ saveckp_freq: 20
66
+ seed: 0
67
+ num_workers: 10
68
+ OFFICIAL_EPOCH_LENGTH: 1250
69
+ cache_dataset: true
70
+ centering: "centering" # or "sinkhorn_knopp"
71
+ student:
72
+ arch: vit_large
73
+ patch_size: 16
74
+ drop_path_rate: 0.3
75
+ layerscale: 1.0e-05
76
+ drop_path_uniform: true
77
+ pretrained_weights: ''
78
+ ffn_layer: "mlp"
79
+ block_chunks: 0
80
+ qkv_bias: true
81
+ proj_bias: true
82
+ ffn_bias: true
83
+ teacher:
84
+ momentum_teacher: 0.992
85
+ final_momentum_teacher: 1
86
+ warmup_teacher_temp: 0.04
87
+ teacher_temp: 0.07
88
+ warmup_teacher_temp_epochs: 30
89
+ optim:
90
+ epochs: 100
91
+ weight_decay: 0.04
92
+ weight_decay_end: 0.4
93
+ base_lr: 0.004 # learning rate for a batch size of 1024
94
+ lr: 0. # will be set after applying scaling rule
95
+ warmup_epochs: 10
96
+ min_lr: 1.0e-06
97
+ clip_grad: 3.0
98
+ freeze_last_layer_epochs: 1
99
+ scaling_rule: sqrt_wrt_1024
100
+ patch_embed_lr_mult: 0.2
101
+ layerwise_decay: 0.9
102
+ adamw_beta1: 0.9
103
+ adamw_beta2: 0.999
104
+ crops:
105
+ global_crops_scale:
106
+ - 0.32
107
+ - 1.0
108
+ local_crops_number: 8
109
+ local_crops_scale:
110
+ - 0.05
111
+ - 0.32
112
+ global_crops_size: 224
113
+ local_crops_size: 96
114
+ evaluation:
115
+ eval_period_iterations: 12500
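
In this default config, `optim.lr` is deliberately left at 0 and is derived from `base_lr` at train time using the `sqrt_wrt_1024` scaling rule. A rough illustration of that rule (the exact computation lives in the training code, which is not part of this commit, so treat the formula as an approximation):

```python
import math

base_lr = 0.004          # optim.base_lr, quoted for a global batch size of 1024
batch_size_per_gpu = 64  # train.batch_size_per_gpu
world_size = 8           # hypothetical number of GPUs

global_batch_size = batch_size_per_gpu * world_size  # 512
lr = base_lr * math.sqrt(global_batch_size / 1024.0)
print(round(lr, 6))  # 0.002828
```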
dinov2/configs/train/vitg14.yaml ADDED
@@ -0,0 +1,26 @@
+ dino:
+   head_n_prototypes: 131072
+   head_bottleneck_dim: 384
+ ibot:
+   separate_head: true
+   head_n_prototypes: 131072
+ train:
+   batch_size_per_gpu: 12
+   dataset_path: ImageNet22k
+   centering: sinkhorn_knopp
+ student:
+   arch: vit_giant2
+   patch_size: 14
+   drop_path_rate: 0.4
+   ffn_layer: swiglufused
+   block_chunks: 4
+ teacher:
+   momentum_teacher: 0.994
+ optim:
+   epochs: 500
+   weight_decay_end: 0.2
+   base_lr: 2.0e-04 # learning rate for a batch size of 1024
+   warmup_epochs: 80
+   layerwise_decay: 1.0
+ crops:
+   local_crops_size: 98
dinov2/configs/train/vitl14.yaml ADDED
@@ -0,0 +1,26 @@
+ dino:
+   head_n_prototypes: 131072
+   head_bottleneck_dim: 384
+ ibot:
+   separate_head: true
+   head_n_prototypes: 131072
+ train:
+   batch_size_per_gpu: 32
+   dataset_path: ImageNet22k
+   centering: sinkhorn_knopp
+ student:
+   arch: vit_large
+   patch_size: 14
+   drop_path_rate: 0.4
+   ffn_layer: swiglufused
+   block_chunks: 4
+ teacher:
+   momentum_teacher: 0.994
+ optim:
+   epochs: 500
+   weight_decay_end: 0.2
+   base_lr: 2.0e-04 # learning rate for a batch size of 1024
+   warmup_epochs: 80
+   layerwise_decay: 1.0
+ crops:
+   local_crops_size: 98
dinov2/configs/train/vitl16_short.yaml ADDED
@@ -0,0 +1,6 @@
+ # this corresponds to the default config
+ train:
+   dataset_path: ImageNet:split=TRAIN
+   batch_size_per_gpu: 64
+ student:
+   block_chunks: 4
dinov2/data/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .adapters import DatasetWithEnumeratedTargets
7
+ from .loaders import make_data_loader, make_dataset, SamplerType
8
+ from .collate import collate_data_and_cast
9
+ from .masking import MaskingGenerator
10
+ from .augmentations import DataAugmentationDINO
dinov2/data/adapters.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, Tuple
7
+
8
+ from torch.utils.data import Dataset
9
+
10
+
11
+ class DatasetWithEnumeratedTargets(Dataset):
12
+ def __init__(self, dataset):
13
+ self._dataset = dataset
14
+
15
+ def get_image_data(self, index: int) -> bytes:
16
+ return self._dataset.get_image_data(index)
17
+
18
+ def get_target(self, index: int) -> Tuple[Any, int]:
19
+ target = self._dataset.get_target(index)
20
+ return (index, target)
21
+
22
+ def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]:
23
+ image, target = self._dataset[index]
24
+ target = index if target is None else target
25
+ return image, (index, target)
26
+
27
+ def __len__(self) -> int:
28
+ return len(self._dataset)
dinov2/data/augmentations.py ADDED
@@ -0,0 +1,118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from torchvision import transforms
9
+
10
+ from .transforms import (
11
+ GaussianBlur,
12
+ make_normalize_transform,
13
+ )
14
+
15
+
16
+ logger = logging.getLogger("dinov2")
17
+
18
+
19
+ class DataAugmentationDINO(object):
20
+ def __init__(
21
+ self,
22
+ global_crops_scale,
23
+ local_crops_scale,
24
+ local_crops_number,
25
+ global_crops_size=224,
26
+ local_crops_size=96,
27
+ ):
28
+ self.global_crops_scale = global_crops_scale
29
+ self.local_crops_scale = local_crops_scale
30
+ self.local_crops_number = local_crops_number
31
+ self.global_crops_size = global_crops_size
32
+ self.local_crops_size = local_crops_size
33
+
34
+ logger.info("###################################")
35
+ logger.info("Using data augmentation parameters:")
36
+ logger.info(f"global_crops_scale: {global_crops_scale}")
37
+ logger.info(f"local_crops_scale: {local_crops_scale}")
38
+ logger.info(f"local_crops_number: {local_crops_number}")
39
+ logger.info(f"global_crops_size: {global_crops_size}")
40
+ logger.info(f"local_crops_size: {local_crops_size}")
41
+ logger.info("###################################")
42
+
43
+ # random resized crop and flip
44
+ self.geometric_augmentation_global = transforms.Compose(
45
+ [
46
+ transforms.RandomResizedCrop(
47
+ global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
48
+ ),
49
+ transforms.RandomHorizontalFlip(p=0.5),
50
+ ]
51
+ )
52
+
53
+ self.geometric_augmentation_local = transforms.Compose(
54
+ [
55
+ transforms.RandomResizedCrop(
56
+ local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
57
+ ),
58
+ transforms.RandomHorizontalFlip(p=0.5),
59
+ ]
60
+ )
61
+
62
+ # color distorsions / blurring
63
+ color_jittering = transforms.Compose(
64
+ [
65
+ transforms.RandomApply(
66
+ [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
67
+ p=0.8,
68
+ ),
69
+ transforms.RandomGrayscale(p=0.2),
70
+ ]
71
+ )
72
+
73
+ global_transfo1_extra = GaussianBlur(p=1.0)
74
+
75
+ global_transfo2_extra = transforms.Compose(
76
+ [
77
+ GaussianBlur(p=0.1),
78
+ transforms.RandomSolarize(threshold=128, p=0.2),
79
+ ]
80
+ )
81
+
82
+ local_transfo_extra = GaussianBlur(p=0.5)
83
+
84
+ # normalization
85
+ self.normalize = transforms.Compose(
86
+ [
87
+ transforms.ToTensor(),
88
+ make_normalize_transform(),
89
+ ]
90
+ )
91
+
92
+ self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize])
93
+ self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize])
94
+ self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize])
95
+
96
+ def __call__(self, image):
97
+ output = {}
98
+
99
+ # global crops:
100
+ im1_base = self.geometric_augmentation_global(image)
101
+ global_crop_1 = self.global_transfo1(im1_base)
102
+
103
+ im2_base = self.geometric_augmentation_global(image)
104
+ global_crop_2 = self.global_transfo2(im2_base)
105
+
106
+ output["global_crops"] = [global_crop_1, global_crop_2]
107
+
108
+ # global crops for teacher:
109
+ output["global_crops_teacher"] = [global_crop_1, global_crop_2]
110
+
111
+ # local crops:
112
+ local_crops = [
113
+ self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number)
114
+ ]
115
+ output["local_crops"] = local_crops
116
+ output["offsets"] = ()
117
+
118
+ return output
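
A small sketch of `DataAugmentationDINO` applied to a single image, using the crop parameters from the default config (the image path is a placeholder):

```python
from PIL import Image
from dinov2.data import DataAugmentationDINO

augmentation = DataAugmentationDINO(
    global_crops_scale=(0.32, 1.0),
    local_crops_scale=(0.05, 0.32),
    local_crops_number=8,
    global_crops_size=224,
    local_crops_size=96,
)

image = Image.open("some_image.jpg").convert("RGB")  # placeholder path
crops = augmentation(image)
print(len(crops["global_crops"]), len(crops["local_crops"]))  # 2 8
print(crops["global_crops"][0].shape)  # torch.Size([3, 224, 224])
```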
dinov2/data/collate.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import random
8
+
9
+
10
+ def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None):
11
+ # dtype = torch.half # TODO: Remove
12
+
13
+ n_global_crops = len(samples_list[0][0]["global_crops"])
14
+ n_local_crops = len(samples_list[0][0]["local_crops"])
15
+
16
+ collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list])
17
+
18
+ collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list])
19
+
20
+ B = len(collated_global_crops)
21
+ N = n_tokens
22
+ n_samples_masked = int(B * mask_probability)
23
+ probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1)
24
+ upperbound = 0
25
+ masks_list = []
26
+ for i in range(0, n_samples_masked):
27
+ prob_min = probs[i]
28
+ prob_max = probs[i + 1]
29
+ masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max)))))
30
+ upperbound += int(N * prob_max)
31
+ for i in range(n_samples_masked, B):
32
+ masks_list.append(torch.BoolTensor(mask_generator(0)))
33
+
34
+ random.shuffle(masks_list)
35
+
36
+ collated_masks = torch.stack(masks_list).flatten(1)
37
+ mask_indices_list = collated_masks.flatten().nonzero().flatten()
38
+
39
+ masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks]
40
+
41
+ return {
42
+ "collated_global_crops": collated_global_crops.to(dtype),
43
+ "collated_local_crops": collated_local_crops.to(dtype),
44
+ "collated_masks": collated_masks,
45
+ "mask_indices_list": mask_indices_list,
46
+ "masks_weight": masks_weight,
47
+ "upperbound": upperbound,
48
+ "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long),
49
+ }
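
`collate_data_and_cast` is designed to be partially applied and handed to the data loader as `collate_fn`. A sketch of how it can be wired up with the `MaskingGenerator` from `masking.py` below, using the masking values from the default `ibot` section:

```python
from functools import partial

import torch

from dinov2.data import MaskingGenerator, collate_data_and_cast

img_size, patch_size = 224, 14
n_tokens = (img_size // patch_size) ** 2  # 256 patch tokens per global crop

mask_generator = MaskingGenerator(
    input_size=(img_size // patch_size, img_size // patch_size),
    max_num_patches=0.5 * n_tokens,
)

collate_fn = partial(
    collate_data_and_cast,
    mask_ratio_tuple=(0.1, 0.5),  # ibot.mask_ratio_min_max
    mask_probability=0.5,         # ibot.mask_sample_probability
    dtype=torch.half,
    n_tokens=n_tokens,
    mask_generator=mask_generator,
)
```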
dinov2/data/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .image_net import ImageNet
7
+ from .image_net_22k import ImageNet22k
dinov2/data/datasets/decoders.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from io import BytesIO
7
+ from typing import Any
8
+
9
+ from PIL import Image
10
+
11
+
12
+ class Decoder:
13
+ def decode(self) -> Any:
14
+ raise NotImplementedError
15
+
16
+
17
+ class ImageDataDecoder(Decoder):
18
+ def __init__(self, image_data: bytes) -> None:
19
+ self._image_data = image_data
20
+
21
+ def decode(self) -> Image:
22
+ f = BytesIO(self._image_data)
23
+ return Image.open(f).convert(mode="RGB")
24
+
25
+
26
+ class TargetDecoder(Decoder):
27
+ def __init__(self, target: Any):
28
+ self._target = target
29
+
30
+ def decode(self) -> Any:
31
+ return self._target
dinov2/data/datasets/extended.py ADDED
@@ -0,0 +1,38 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Any, Tuple
7
+
8
+ from torchvision.datasets import VisionDataset
9
+
10
+ from .decoders import TargetDecoder, ImageDataDecoder
11
+
12
+
13
+ class ExtendedVisionDataset(VisionDataset):
14
+ def __init__(self, *args, **kwargs) -> None:
15
+ super().__init__(*args, **kwargs) # type: ignore
16
+
17
+ def get_image_data(self, index: int) -> bytes:
18
+ raise NotImplementedError
19
+
20
+ def get_target(self, index: int) -> Any:
21
+ raise NotImplementedError
22
+
23
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
24
+ try:
25
+ image_data = self.get_image_data(index)
26
+ image = ImageDataDecoder(image_data).decode()
27
+ except Exception as e:
28
+ raise RuntimeError(f"can not read image for sample {index}") from e
29
+ target = self.get_target(index)
30
+ target = TargetDecoder(target).decode()
31
+
32
+ if self.transforms is not None:
33
+ image, target = self.transforms(image, target)
34
+
35
+ return image, target
36
+
37
+ def __len__(self) -> int:
38
+ raise NotImplementedError
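
`ExtendedVisionDataset` only requires subclasses to provide raw image bytes, a target and a length; decoding and transforms are handled above. A hypothetical minimal subclass, for illustration only (the real implementations follow in `image_net.py` and `image_net_22k.py`):

```python
import os
from typing import Any

from dinov2.data.datasets.extended import ExtendedVisionDataset


class JpegFolder(ExtendedVisionDataset):
    """Unlabeled folder of .jpg files (illustrative example)."""

    def __init__(self, root: str, **kwargs) -> None:
        super().__init__(root, **kwargs)
        self._paths = sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.lower().endswith(".jpg")
        )

    def get_image_data(self, index: int) -> bytes:
        with open(self._paths[index], "rb") as f:
            return f.read()

    def get_target(self, index: int) -> Any:
        return None  # no labels

    def __len__(self) -> int:
        return len(self._paths)
```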
dinov2/data/datasets/image_net.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import csv
7
+ from enum import Enum
8
+ import logging
9
+ import os
10
+ from typing import Callable, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+
14
+ from .extended import ExtendedVisionDataset
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+ _Target = int
19
+
20
+
21
+ class _Split(Enum):
22
+ TRAIN = "train"
23
+ VAL = "val"
24
+ TEST = "test" # NOTE: torchvision does not support the test split
25
+
26
+ @property
27
+ def length(self) -> int:
28
+ split_lengths = {
29
+ _Split.TRAIN: 1_281_167,
30
+ _Split.VAL: 50_000,
31
+ _Split.TEST: 100_000,
32
+ }
33
+ return split_lengths[self]
34
+
35
+ def get_dirname(self, class_id: Optional[str] = None) -> str:
36
+ return self.value if class_id is None else os.path.join(self.value, class_id)
37
+
38
+ def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str:
39
+ dirname = self.get_dirname(class_id)
40
+ if self == _Split.TRAIN:
41
+ basename = f"{class_id}_{actual_index}"
42
+ else: # self in (_Split.VAL, _Split.TEST):
43
+ basename = f"ILSVRC2012_{self.value}_{actual_index:08d}"
44
+ return os.path.join(dirname, basename + ".JPEG")
45
+
46
+ def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]:
47
+ assert self != _Split.TEST
48
+ dirname, filename = os.path.split(image_relpath)
49
+ class_id = os.path.split(dirname)[-1]
50
+ basename, _ = os.path.splitext(filename)
51
+ actual_index = int(basename.split("_")[-1])
52
+ return class_id, actual_index
53
+
54
+
55
+ class ImageNet(ExtendedVisionDataset):
56
+ Target = Union[_Target]
57
+ Split = Union[_Split]
58
+
59
+ def __init__(
60
+ self,
61
+ *,
62
+ split: "ImageNet.Split",
63
+ root: str,
64
+ extra: str,
65
+ transforms: Optional[Callable] = None,
66
+ transform: Optional[Callable] = None,
67
+ target_transform: Optional[Callable] = None,
68
+ ) -> None:
69
+ super().__init__(root, transforms, transform, target_transform)
70
+ self._extra_root = extra
71
+ self._split = split
72
+
73
+ self._entries = None
74
+ self._class_ids = None
75
+ self._class_names = None
76
+
77
+ @property
78
+ def split(self) -> "ImageNet.Split":
79
+ return self._split
80
+
81
+ def _get_extra_full_path(self, extra_path: str) -> str:
82
+ return os.path.join(self._extra_root, extra_path)
83
+
84
+ def _load_extra(self, extra_path: str) -> np.ndarray:
85
+ extra_full_path = self._get_extra_full_path(extra_path)
86
+ return np.load(extra_full_path, mmap_mode="r")
87
+
88
+ def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
89
+ extra_full_path = self._get_extra_full_path(extra_path)
90
+ os.makedirs(self._extra_root, exist_ok=True)
91
+ np.save(extra_full_path, extra_array)
92
+
93
+ @property
94
+ def _entries_path(self) -> str:
95
+ return f"entries-{self._split.value.upper()}.npy"
96
+
97
+ @property
98
+ def _class_ids_path(self) -> str:
99
+ return f"class-ids-{self._split.value.upper()}.npy"
100
+
101
+ @property
102
+ def _class_names_path(self) -> str:
103
+ return f"class-names-{self._split.value.upper()}.npy"
104
+
105
+ def _get_entries(self) -> np.ndarray:
106
+ if self._entries is None:
107
+ self._entries = self._load_extra(self._entries_path)
108
+ assert self._entries is not None
109
+ return self._entries
110
+
111
+ def _get_class_ids(self) -> np.ndarray:
112
+ if self._split == _Split.TEST:
113
+ assert False, "Class IDs are not available in TEST split"
114
+ if self._class_ids is None:
115
+ self._class_ids = self._load_extra(self._class_ids_path)
116
+ assert self._class_ids is not None
117
+ return self._class_ids
118
+
119
+ def _get_class_names(self) -> np.ndarray:
120
+ if self._split == _Split.TEST:
121
+ assert False, "Class names are not available in TEST split"
122
+ if self._class_names is None:
123
+ self._class_names = self._load_extra(self._class_names_path)
124
+ assert self._class_names is not None
125
+ return self._class_names
126
+
127
+ def find_class_id(self, class_index: int) -> str:
128
+ class_ids = self._get_class_ids()
129
+ return str(class_ids[class_index])
130
+
131
+ def find_class_name(self, class_index: int) -> str:
132
+ class_names = self._get_class_names()
133
+ return str(class_names[class_index])
134
+
135
+ def get_image_data(self, index: int) -> bytes:
136
+ entries = self._get_entries()
137
+ actual_index = entries[index]["actual_index"]
138
+
139
+ class_id = self.get_class_id(index)
140
+
141
+ image_relpath = self.split.get_image_relpath(actual_index, class_id)
142
+ image_full_path = os.path.join(self.root, image_relpath)
143
+ with open(image_full_path, mode="rb") as f:
144
+ image_data = f.read()
145
+ return image_data
146
+
147
+ def get_target(self, index: int) -> Optional[Target]:
148
+ entries = self._get_entries()
149
+ class_index = entries[index]["class_index"]
150
+ return None if self.split == _Split.TEST else int(class_index)
151
+
152
+ def get_targets(self) -> Optional[np.ndarray]:
153
+ entries = self._get_entries()
154
+ return None if self.split == _Split.TEST else entries["class_index"]
155
+
156
+ def get_class_id(self, index: int) -> Optional[str]:
157
+ entries = self._get_entries()
158
+ class_id = entries[index]["class_id"]
159
+ return None if self.split == _Split.TEST else str(class_id)
160
+
161
+ def get_class_name(self, index: int) -> Optional[str]:
162
+ entries = self._get_entries()
163
+ class_name = entries[index]["class_name"]
164
+ return None if self.split == _Split.TEST else str(class_name)
165
+
166
+ def __len__(self) -> int:
167
+ entries = self._get_entries()
168
+ assert len(entries) == self.split.length
169
+ return len(entries)
170
+
171
+ def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]:
172
+ labels_full_path = os.path.join(self.root, labels_path)
173
+ labels = []
174
+
175
+ try:
176
+ with open(labels_full_path, "r") as f:
177
+ reader = csv.reader(f)
178
+ for row in reader:
179
+ class_id, class_name = row
180
+ labels.append((class_id, class_name))
181
+ except OSError as e:
182
+ raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e
183
+
184
+ return labels
185
+
186
+ def _dump_entries(self) -> None:
187
+ split = self.split
188
+ if split == ImageNet.Split.TEST:
189
+ dataset = None
190
+ sample_count = split.length
191
+ max_class_id_length, max_class_name_length = 0, 0
192
+ else:
193
+ labels_path = "labels.txt"
194
+ logger.info(f'loading labels from "{labels_path}"')
195
+ labels = self._load_labels(labels_path)
196
+
197
+ # NOTE: Using torchvision ImageFolder for consistency
198
+ from torchvision.datasets import ImageFolder
199
+
200
+ dataset_root = os.path.join(self.root, split.get_dirname())
201
+ dataset = ImageFolder(dataset_root)
202
+ sample_count = len(dataset)
203
+ max_class_id_length, max_class_name_length = -1, -1
204
+ for sample in dataset.samples:
205
+ _, class_index = sample
206
+ class_id, class_name = labels[class_index]
207
+ max_class_id_length = max(len(class_id), max_class_id_length)
208
+ max_class_name_length = max(len(class_name), max_class_name_length)
209
+
210
+ dtype = np.dtype(
211
+ [
212
+ ("actual_index", "<u4"),
213
+ ("class_index", "<u4"),
214
+ ("class_id", f"U{max_class_id_length}"),
215
+ ("class_name", f"U{max_class_name_length}"),
216
+ ]
217
+ )
218
+ entries_array = np.empty(sample_count, dtype=dtype)
219
+
220
+ if split == ImageNet.Split.TEST:
221
+ old_percent = -1
222
+ for index in range(sample_count):
223
+ percent = 100 * (index + 1) // sample_count
224
+ if percent > old_percent:
225
+ logger.info(f"creating entries: {percent}%")
226
+ old_percent = percent
227
+
228
+ actual_index = index + 1
229
+ class_index = np.uint32(-1)
230
+ class_id, class_name = "", ""
231
+ entries_array[index] = (actual_index, class_index, class_id, class_name)
232
+ else:
233
+ class_names = {class_id: class_name for class_id, class_name in labels}
234
+
235
+ assert dataset
236
+ old_percent = -1
237
+ for index in range(sample_count):
238
+ percent = 100 * (index + 1) // sample_count
239
+ if percent > old_percent:
240
+ logger.info(f"creating entries: {percent}%")
241
+ old_percent = percent
242
+
243
+ image_full_path, class_index = dataset.samples[index]
244
+ image_relpath = os.path.relpath(image_full_path, self.root)
245
+ class_id, actual_index = split.parse_image_relpath(image_relpath)
246
+ class_name = class_names[class_id]
247
+ entries_array[index] = (actual_index, class_index, class_id, class_name)
248
+
249
+ logger.info(f'saving entries to "{self._entries_path}"')
250
+ self._save_extra(entries_array, self._entries_path)
251
+
252
+ def _dump_class_ids_and_names(self) -> None:
253
+ split = self.split
254
+ if split == ImageNet.Split.TEST:
255
+ return
256
+
257
+ entries_array = self._load_extra(self._entries_path)
258
+
259
+ max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1
260
+ for entry in entries_array:
261
+ class_index, class_id, class_name = (
262
+ entry["class_index"],
263
+ entry["class_id"],
264
+ entry["class_name"],
265
+ )
266
+ max_class_index = max(int(class_index), max_class_index)
267
+ max_class_id_length = max(len(str(class_id)), max_class_id_length)
268
+ max_class_name_length = max(len(str(class_name)), max_class_name_length)
269
+
270
+ class_count = max_class_index + 1
271
+ class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}")
272
+ class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}")
273
+ for entry in entries_array:
274
+ class_index, class_id, class_name = (
275
+ entry["class_index"],
276
+ entry["class_id"],
277
+ entry["class_name"],
278
+ )
279
+ class_ids_array[class_index] = class_id
280
+ class_names_array[class_index] = class_name
281
+
282
+ logger.info(f'saving class IDs to "{self._class_ids_path}"')
283
+ self._save_extra(class_ids_array, self._class_ids_path)
284
+
285
+ logger.info(f'saving class names to "{self._class_names_path}"')
286
+ self._save_extra(class_names_array, self._class_names_path)
287
+
288
+ def dump_extra(self) -> None:
289
+ self._dump_entries()
290
+ self._dump_class_ids_and_names()
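
`dump_extra` is the one-off step that writes the `entries-*.npy`, `class-ids-*.npy` and `class-names-*.npy` metadata this class later memory-maps from `extra`. A sketch with hypothetical local paths, assuming the standard `train/` and `val/` layout plus the `labels.txt` file expected by `_load_labels` above:

```python
from dinov2.data.datasets import ImageNet

root = "/datasets/imagenet"          # hypothetical: train/, val/ and labels.txt
extra = "/datasets/imagenet-extra"   # hypothetical: where the .npy metadata is written

for split in ImageNet.Split:
    dataset = ImageNet(split=split, root=root, extra=extra)
    dataset.dump_extra()
```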
dinov2/data/datasets/image_net_22k.py ADDED
@@ -0,0 +1,302 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from dataclasses import dataclass
7
+ from enum import Enum
8
+ from functools import lru_cache
9
+ from gzip import GzipFile
10
+ from io import BytesIO
11
+ from mmap import ACCESS_READ, mmap
12
+ import os
13
+ from typing import Any, Callable, List, Optional, Set, Tuple
14
+ import warnings
15
+
16
+ import numpy as np
17
+
18
+ from .extended import ExtendedVisionDataset
19
+
20
+
21
+ _Labels = int
22
+
23
+ _DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors
24
+
25
+
26
+ @dataclass
27
+ class _ClassEntry:
28
+ block_offset: int
29
+ maybe_filename: Optional[str] = None
30
+
31
+
32
+ @dataclass
33
+ class _Entry:
34
+ class_index: int # noqa: E701
35
+ start_offset: int
36
+ end_offset: int
37
+ filename: str
38
+
39
+
40
+ class _Split(Enum):
41
+ TRAIN = "train"
42
+ VAL = "val"
43
+
44
+ @property
45
+ def length(self) -> int:
46
+ return {
47
+ _Split.TRAIN: 11_797_647,
48
+ _Split.VAL: 561_050,
49
+ }[self]
50
+
51
+ def entries_path(self):
52
+ return f"imagenet21kp_{self.value}.txt"
53
+
54
+
55
+ def _get_tarball_path(class_id: str) -> str:
56
+ return f"{class_id}.tar"
57
+
58
+
59
+ def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int):
60
+ @lru_cache(maxsize=mmap_cache_size)
61
+ def _mmap_tarball(class_id: str) -> mmap:
62
+ tarball_path = _get_tarball_path(class_id)
63
+ tarball_full_path = os.path.join(tarballs_root, tarball_path)
64
+ with open(tarball_full_path) as f:
65
+ return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ)
66
+
67
+ return _mmap_tarball
68
+
69
+
70
+ class ImageNet22k(ExtendedVisionDataset):
71
+ _GZIPPED_INDICES: Set[int] = {
72
+ 841_545,
73
+ 1_304_131,
74
+ 2_437_921,
75
+ 2_672_079,
76
+ 2_795_676,
77
+ 2_969_786,
78
+ 6_902_965,
79
+ 6_903_550,
80
+ 6_903_628,
81
+ 7_432_557,
82
+ 7_432_589,
83
+ 7_813_809,
84
+ 8_329_633,
85
+ 10_296_990,
86
+ 10_417_652,
87
+ 10_492_265,
88
+ 10_598_078,
89
+ 10_782_398,
90
+ 10_902_612,
91
+ 11_203_736,
92
+ 11_342_890,
93
+ 11_397_596,
94
+ 11_589_762,
95
+ 11_705_103,
96
+ 12_936_875,
97
+ 13_289_782,
98
+ }
99
+ Labels = _Labels
100
+
101
+ def __init__(
102
+ self,
103
+ *,
104
+ root: str,
105
+ extra: str,
106
+ transforms: Optional[Callable] = None,
107
+ transform: Optional[Callable] = None,
108
+ target_transform: Optional[Callable] = None,
109
+ mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE,
110
+ ) -> None:
111
+ super().__init__(root, transforms, transform, target_transform)
112
+ self._extra_root = extra
113
+
114
+ entries_path = self._get_entries_path(root)
115
+ self._entries = self._load_extra(entries_path)
116
+
117
+ class_ids_path = self._get_class_ids_path(root)
118
+ self._class_ids = self._load_extra(class_ids_path)
119
+
120
+ self._gzipped_indices = ImageNet22k._GZIPPED_INDICES
121
+ self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size)
122
+
123
+ def _get_entries_path(self, root: Optional[str] = None) -> str:
124
+ return "entries.npy"
125
+
126
+ def _get_class_ids_path(self, root: Optional[str] = None) -> str:
127
+ return "class-ids.npy"
128
+
129
+ def _find_class_ids(self, path: str) -> List[str]:
130
+ class_ids = []
131
+
132
+ with os.scandir(path) as entries:
133
+ for entry in entries:
134
+ root, ext = os.path.splitext(entry.name)
135
+ if ext != ".tar":
136
+ continue
137
+ class_ids.append(root)
138
+
139
+ return sorted(class_ids)
140
+
141
+ def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]:
142
+ root = self.get_root(root)
143
+ entries: List[_Entry] = []
144
+ class_ids = self._find_class_ids(root)
145
+
146
+ for class_index, class_id in enumerate(class_ids):
147
+ path = os.path.join(root, "blocks", f"{class_id}.log")
148
+ class_entries = []
149
+
150
+ try:
151
+ with open(path) as f:
152
+ for line in f:
153
+ line = line.rstrip()
154
+ block, filename = line.split(":")
155
+ block_offset = int(block[6:])
156
+ filename = filename[1:]
157
+
158
+ maybe_filename = None
159
+ if filename != "** Block of NULs **":
160
+ maybe_filename = filename
161
+ _, ext = os.path.splitext(filename)
162
+ # assert ext == ".JPEG"
163
+
164
+ class_entry = _ClassEntry(block_offset, maybe_filename)
165
+ class_entries.append(class_entry)
166
+ except OSError as e:
167
+ raise RuntimeError(f'can not read blocks file "{path}"') from e
168
+
169
+ assert class_entries[-1].maybe_filename is None
170
+
171
+ for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]):
172
+ assert class_entry1.block_offset <= class_entry2.block_offset
173
+ start_offset = 512 * class_entry1.block_offset
174
+ end_offset = 512 * class_entry2.block_offset
175
+ assert class_entry1.maybe_filename is not None
176
+ filename = class_entry1.maybe_filename
177
+ entry = _Entry(class_index, start_offset, end_offset, filename)
178
+ # Skip invalid image files (PIL throws UnidentifiedImageError)
179
+ if filename == "n06470073_47249.JPEG":
180
+ continue
181
+ entries.append(entry)
182
+
183
+ return entries, class_ids
184
+
185
+ def _load_extra(self, extra_path: str) -> np.ndarray:
186
+ extra_root = self._extra_root
187
+ extra_full_path = os.path.join(extra_root, extra_path)
188
+ return np.load(extra_full_path, mmap_mode="r")
189
+
190
+ def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None:
191
+ extra_root = self._extra_root
192
+ extra_full_path = os.path.join(extra_root, extra_path)
193
+ os.makedirs(extra_root, exist_ok=True)
194
+ np.save(extra_full_path, extra_array)
195
+
196
+ @property
197
+ def _tarballs_root(self) -> str:
198
+ return self.root
199
+
200
+ def find_class_id(self, class_index: int) -> str:
201
+ return str(self._class_ids[class_index])
202
+
203
+ def get_image_data(self, index: int) -> bytes:
204
+ entry = self._entries[index]
205
+ class_id = entry["class_id"]
206
+ class_mmap = self._mmap_tarball(class_id)
207
+
208
+ start_offset, end_offset = entry["start_offset"], entry["end_offset"]
209
+ try:
210
+ mapped_data = class_mmap[start_offset:end_offset]
211
+ data = mapped_data[512:] # Skip entry header block
212
+
213
+ if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B):
214
+ assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}"
215
+ with GzipFile(fileobj=BytesIO(data)) as g:
216
+ data = g.read()
217
+ except Exception as e:
218
+ raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e
219
+
220
+ return data
221
+
222
+ def get_target(self, index: int) -> Any:
223
+ return int(self._entries[index]["class_index"])
224
+
225
+ def get_targets(self) -> np.ndarray:
226
+ return self._entries["class_index"]
227
+
228
+ def get_class_id(self, index: int) -> str:
229
+ return str(self._entries[index]["class_id"])
230
+
231
+ def get_class_ids(self) -> np.ndarray:
232
+ return self._entries["class_id"]
233
+
234
+ def __getitem__(self, index: int) -> Tuple[Any, Any]:
235
+ with warnings.catch_warnings():
236
+ warnings.simplefilter("ignore")
237
+ return super().__getitem__(index)
238
+
239
+ def __len__(self) -> int:
240
+ return len(self._entries)
241
+
242
+ def _dump_entries(self, *args, **kwargs) -> None:
243
+ entries, class_ids = self._load_entries_class_ids(*args, **kwargs)
244
+
245
+ max_class_id_length, max_filename_length, max_class_index = -1, -1, -1
246
+ for entry in entries:
247
+ class_id = class_ids[entry.class_index]
248
+ max_class_index = max(entry.class_index, max_class_index)
249
+ max_class_id_length = max(len(class_id), max_class_id_length)
250
+ max_filename_length = max(len(entry.filename), max_filename_length)
251
+
252
+ dtype = np.dtype(
253
+ [
254
+ ("class_index", "<u4"),
255
+ ("class_id", f"U{max_class_id_length}"),
256
+ ("start_offset", "<u4"),
257
+ ("end_offset", "<u4"),
258
+ ("filename", f"U{max_filename_length}"),
259
+ ]
260
+ )
261
+ sample_count = len(entries)
262
+ entries_array = np.empty(sample_count, dtype=dtype)
263
+ for i, entry in enumerate(entries):
264
+ class_index = entry.class_index
265
+ class_id = class_ids[class_index]
266
+ start_offset = entry.start_offset
267
+ end_offset = entry.end_offset
268
+ filename = entry.filename
269
+ entries_array[i] = (
270
+ class_index,
271
+ class_id,
272
+ start_offset,
273
+ end_offset,
274
+ filename,
275
+ )
276
+
277
+ entries_path = self._get_entries_path(*args, **kwargs)
278
+ self._save_extra(entries_array, entries_path)
279
+
280
+ def _dump_class_ids(self, *args, **kwargs) -> None:
281
+ entries_path = self._get_entries_path(*args, **kwargs)
282
+ entries_array = self._load_extra(entries_path)
283
+
284
+ max_class_id_length, max_class_index = -1, -1
285
+ for entry in entries_array:
286
+ class_index, class_id = entry["class_index"], entry["class_id"]
287
+ max_class_index = max(int(class_index), max_class_index)
288
+ max_class_id_length = max(len(str(class_id)), max_class_id_length)
289
+
290
+ class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}")
291
+ for entry in entries_array:
292
+ class_index, class_id = entry["class_index"], entry["class_id"]
293
+ class_ids_array[class_index] = class_id
294
+ class_ids_path = self._get_class_ids_path(*args, **kwargs)
295
+ self._save_extra(class_ids_array, class_ids_path)
296
+
297
+ def _dump_extra(self, *args, **kwargs) -> None:
298
+ self._dump_entries(*args, *kwargs)
299
+ self._dump_class_ids(*args, *kwargs)
300
+
301
+ def dump_extra(self, root: Optional[str] = None) -> None:
302
+ return self._dump_extra(root)
dinov2/data/loaders.py ADDED
@@ -0,0 +1,222 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ from enum import Enum
8
+ from typing import Any, Callable, List, Optional, TypeVar
9
+
10
+ import torch
11
+ from torch.utils.data import Sampler
12
+
13
+ from .datasets import ImageNet, ImageNet22k
14
+ from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ class SamplerType(Enum):
21
+ DISTRIBUTED = 0
22
+ EPOCH = 1
23
+ INFINITE = 2
24
+ SHARDED_INFINITE = 3
25
+ SHARDED_INFINITE_NEW = 4
26
+
27
+
28
+ def _make_bool_str(b: bool) -> str:
29
+ return "yes" if b else "no"
30
+
31
+
32
+ def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None):
33
+ def transform(sample):
34
+ image, target = sample
35
+ if image_transform is not None:
36
+ image = image_transform(image)
37
+ if target_transform is not None:
38
+ target = target_transform(target)
39
+ return image, target
40
+
41
+ return transform
42
+
43
+
44
+ def _parse_dataset_str(dataset_str: str):
45
+ tokens = dataset_str.split(":")
46
+
47
+ name = tokens[0]
48
+ kwargs = {}
49
+
50
+ for token in tokens[1:]:
51
+ key, value = token.split("=")
52
+ assert key in ("root", "extra", "split")
53
+ kwargs[key] = value
54
+
55
+ if name == "ImageNet":
56
+ class_ = ImageNet
57
+ if "split" in kwargs:
58
+ kwargs["split"] = ImageNet.Split[kwargs["split"]]
59
+ elif name == "ImageNet22k":
60
+ class_ = ImageNet22k
61
+ else:
62
+ raise ValueError(f'Unsupported dataset "{name}"')
63
+
64
+ return class_, kwargs
65
+
66
+
67
+ def make_dataset(
68
+ *,
69
+ dataset_str: str,
70
+ transform: Optional[Callable] = None,
71
+ target_transform: Optional[Callable] = None,
72
+ ):
73
+ """
74
+ Creates a dataset with the specified parameters.
75
+
76
+ Args:
77
+ dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN).
78
+ transform: A transform to apply to images.
79
+ target_transform: A transform to apply to targets.
80
+
81
+ Returns:
82
+ The created dataset.
83
+ """
84
+ logger.info(f'using dataset: "{dataset_str}"')
85
+
86
+ class_, kwargs = _parse_dataset_str(dataset_str)
87
+ dataset = class_(transform=transform, target_transform=target_transform, **kwargs)
88
+
89
+ logger.info(f"# of dataset samples: {len(dataset):,d}")
90
+
91
+ # Aggregated datasets do not expose (yet) these attributes, so add them.
92
+ if not hasattr(dataset, "transform"):
93
+ setattr(dataset, "transform", transform)
94
+ if not hasattr(dataset, "target_transform"):
95
+ setattr(dataset, "target_transform", target_transform)
96
+
97
+ return dataset
98
+
99
+
100
+ def _make_sampler(
101
+ *,
102
+ dataset,
103
+ type: Optional[SamplerType] = None,
104
+ shuffle: bool = False,
105
+ seed: int = 0,
106
+ size: int = -1,
107
+ advance: int = 0,
108
+ ) -> Optional[Sampler]:
109
+ sample_count = len(dataset)
110
+
111
+ if type == SamplerType.INFINITE:
112
+ logger.info("sampler: infinite")
113
+ if size > 0:
114
+ raise ValueError("sampler size > 0 is invalid")
115
+ return InfiniteSampler(
116
+ sample_count=sample_count,
117
+ shuffle=shuffle,
118
+ seed=seed,
119
+ advance=advance,
120
+ )
121
+ elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW):
122
+ logger.info("sampler: sharded infinite")
123
+ if size > 0:
124
+ raise ValueError("sampler size > 0 is invalid")
125
+ # TODO: Remove support for old shuffling
126
+ use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW
127
+ return ShardedInfiniteSampler(
128
+ sample_count=sample_count,
129
+ shuffle=shuffle,
130
+ seed=seed,
131
+ advance=advance,
132
+ use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice,
133
+ )
134
+ elif type == SamplerType.EPOCH:
135
+ logger.info("sampler: epoch")
136
+ if advance > 0:
137
+ raise NotImplementedError("sampler advance > 0 is not supported")
138
+ size = size if size > 0 else sample_count
139
+ logger.info(f"# of samples / epoch: {size:,d}")
140
+ return EpochSampler(
141
+ size=size,
142
+ sample_count=sample_count,
143
+ shuffle=shuffle,
144
+ seed=seed,
145
+ )
146
+ elif type == SamplerType.DISTRIBUTED:
147
+ logger.info("sampler: distributed")
148
+ if size > 0:
149
+ raise ValueError("sampler size > 0 is invalid")
150
+ if advance > 0:
151
+ raise ValueError("sampler advance > 0 is invalid")
152
+ return torch.utils.data.DistributedSampler(
153
+ dataset=dataset,
154
+ shuffle=shuffle,
155
+ seed=seed,
156
+ drop_last=False,
157
+ )
158
+
159
+ logger.info("sampler: none")
160
+ return None
161
+
162
+
163
+ T = TypeVar("T")
164
+
165
+
166
+ def make_data_loader(
167
+ *,
168
+ dataset,
169
+ batch_size: int,
170
+ num_workers: int,
171
+ shuffle: bool = True,
172
+ seed: int = 0,
173
+ sampler_type: Optional[SamplerType] = SamplerType.INFINITE,
174
+ sampler_size: int = -1,
175
+ sampler_advance: int = 0,
176
+ drop_last: bool = True,
177
+ persistent_workers: bool = False,
178
+ collate_fn: Optional[Callable[[List[T]], Any]] = None,
179
+ ):
180
+ """
181
+ Creates a data loader with the specified parameters.
182
+
183
+ Args:
184
+ dataset: A dataset (third party, LaViDa or WebDataset).
185
+ batch_size: The size of batches to generate.
186
+ num_workers: The number of workers to use.
187
+ shuffle: Whether to shuffle samples.
188
+ seed: The random seed to use.
189
+ sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None.
190
+ sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset.
191
+ sampler_advance: How many samples to skip (when applicable).
192
+ drop_last: Whether the last non-full batch of data should be dropped.
193
+ persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once.
194
+ collate_fn: Function that performs batch collation
195
+ """
196
+
197
+ sampler = _make_sampler(
198
+ dataset=dataset,
199
+ type=sampler_type,
200
+ shuffle=shuffle,
201
+ seed=seed,
202
+ size=sampler_size,
203
+ advance=sampler_advance,
204
+ )
205
+
206
+ logger.info("using PyTorch data loader")
207
+ data_loader = torch.utils.data.DataLoader(
208
+ dataset,
209
+ sampler=sampler,
210
+ batch_size=batch_size,
211
+ num_workers=num_workers,
212
+ pin_memory=True,
213
+ drop_last=drop_last,
214
+ persistent_workers=persistent_workers,
215
+ collate_fn=collate_fn,
216
+ )
217
+
218
+ try:
219
+ logger.info(f"# of batches: {len(data_loader):,d}")
220
+ except TypeError: # data loader has no length
221
+ logger.info("infinite data loader")
222
+ return data_loader
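
A sketch of `make_dataset` and `make_data_loader` used together; the paths are hypothetical and the dataset string follows the `Name:key=value` format handled by `_parse_dataset_str` above:

```python
from dinov2.data import SamplerType, make_data_loader, make_dataset

dataset = make_dataset(
    dataset_str="ImageNet:split=TRAIN:root=/datasets/imagenet:extra=/datasets/imagenet-extra",
    transform=None,  # e.g. a DataAugmentationDINO instance during pretraining
)

data_loader = make_data_loader(
    dataset=dataset,
    batch_size=64,
    num_workers=10,
    shuffle=True,
    sampler_type=SamplerType.SHARDED_INFINITE,
    drop_last=True,
)
```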
dinov2/data/masking.py ADDED
@@ -0,0 +1,86 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import math
8
+ import numpy as np
9
+
10
+
11
+ class MaskingGenerator:
12
+ def __init__(
13
+ self,
14
+ input_size,
15
+ num_masking_patches=None,
16
+ min_num_patches=4,
17
+ max_num_patches=None,
18
+ min_aspect=0.3,
19
+ max_aspect=None,
20
+ ):
21
+ if not isinstance(input_size, tuple):
22
+ input_size = (input_size,) * 2
23
+ self.height, self.width = input_size
24
+
25
+ self.num_patches = self.height * self.width
26
+ self.num_masking_patches = num_masking_patches
27
+
28
+ self.min_num_patches = min_num_patches
29
+ self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches
30
+
31
+ max_aspect = max_aspect or 1 / min_aspect
32
+ self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
33
+
34
+ def __repr__(self):
35
+ repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
36
+ self.height,
37
+ self.width,
38
+ self.min_num_patches,
39
+ self.max_num_patches,
40
+ self.num_masking_patches,
41
+ self.log_aspect_ratio[0],
42
+ self.log_aspect_ratio[1],
43
+ )
44
+ return repr_str
45
+
46
+ def get_shape(self):
47
+ return self.height, self.width
48
+
49
+ def _mask(self, mask, max_mask_patches):
50
+ delta = 0
51
+ for _ in range(10):
52
+ target_area = random.uniform(self.min_num_patches, max_mask_patches)
53
+ aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
54
+ h = int(round(math.sqrt(target_area * aspect_ratio)))
55
+ w = int(round(math.sqrt(target_area / aspect_ratio)))
56
+ if w < self.width and h < self.height:
57
+ top = random.randint(0, self.height - h)
58
+ left = random.randint(0, self.width - w)
59
+
60
+ num_masked = mask[top : top + h, left : left + w].sum()
61
+ # Overlap
62
+ if 0 < h * w - num_masked <= max_mask_patches:
63
+ for i in range(top, top + h):
64
+ for j in range(left, left + w):
65
+ if mask[i, j] == 0:
66
+ mask[i, j] = 1
67
+ delta += 1
68
+
69
+ if delta > 0:
70
+ break
71
+ return delta
72
+
73
+ def __call__(self, num_masking_patches=0):
74
+ mask = np.zeros(shape=self.get_shape(), dtype=bool)
75
+ mask_count = 0
76
+ while mask_count < num_masking_patches:
77
+ max_mask_patches = num_masking_patches - mask_count
78
+ max_mask_patches = min(max_mask_patches, self.max_num_patches)
79
+
80
+ delta = self._mask(mask, max_mask_patches)
81
+ if delta == 0:
82
+ break
83
+ else:
84
+ mask_count += delta
85
+
86
+ return mask
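
A brief usage sketch of the block-wise masking generator above; the 16 x 16 patch grid and the masking budget are arbitrary assumptions for illustration, not values taken from a config in this repository.

    from dinov2.data.masking import MaskingGenerator   # module added in this commit

    grid = 16   # e.g. a 224-pixel image with 14-pixel patches (assumed)
    mask_generator = MaskingGenerator(
        input_size=grid,
        max_num_patches=grid * grid // 2,   # cap the size of a single masked block
    )

    # Request a budget of patches to mask; the result is a boolean (grid, grid) array.
    mask = mask_generator(num_masking_patches=int(0.3 * grid * grid))
    print(mask.shape, mask.dtype, int(mask.sum()))   # (16, 16) bool, at most 76 patches masked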
dinov2/data/samplers.py ADDED
@@ -0,0 +1,229 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ from typing import Any, Optional
8
+ import warnings
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.utils.data.sampler import Sampler
13
+
14
+ import dinov2.distributed as distributed
15
+
16
+
17
+ class EpochSampler(Sampler):
18
+ def __init__(
19
+ self,
20
+ *,
21
+ size: int,
22
+ sample_count: int,
23
+ shuffle: bool = False,
24
+ seed: int = 0,
25
+ start: Optional[int] = None,
26
+ step: Optional[int] = None,
27
+ ):
28
+ self._size = size
29
+ self._sample_count = sample_count
30
+ self._shuffle = shuffle
31
+ self._seed = seed
32
+ self._start = distributed.get_global_rank() if start is None else start
33
+ self._step = distributed.get_global_size() if step is None else step
34
+ self._epoch = 0
35
+
36
+ def __iter__(self):
37
+ count = (self._size + self._sample_count - 1) // self._sample_count
38
+ tiled_indices = np.tile(np.arange(self._sample_count), count)
39
+ if self._shuffle:
40
+ seed = self._seed * self._epoch if self._seed != 0 else self._epoch
41
+ rng = np.random.default_rng(seed)
42
+ iterable = rng.choice(tiled_indices, self._size, replace=False)
43
+ else:
44
+ iterable = tiled_indices[: self._size]
45
+
46
+ yield from itertools.islice(iterable, self._start, None, self._step)
47
+
48
+ def __len__(self):
49
+ return (self._size - self._start + self._step - 1) // self._step
50
+
51
+ def set_epoch(self, epoch):
52
+ self._epoch = epoch
53
+
54
+
55
+ def _get_numpy_dtype(size: int) -> Any:
56
+ return np.int32 if size <= 2**31 else np.int64
57
+
58
+
59
+ def _get_torch_dtype(size: int) -> Any:
60
+ return torch.int32 if size <= 2**31 else torch.int64
61
+
62
+
63
+ def _generate_randperm_indices(*, size: int, generator: torch.Generator):
64
+ """Generate the indices of a random permutation."""
65
+ dtype = _get_torch_dtype(size)
66
+ # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
67
+ perm = torch.arange(size, dtype=dtype)
68
+ for i in range(size):
69
+ j = torch.randint(i, size, size=(1,), generator=generator).item()
70
+
71
+ # Always swap even if no-op
72
+ value = perm[j].item()
73
+ perm[j] = perm[i].item()
74
+ perm[i] = value
75
+ yield value
76
+
77
+
78
+ class InfiniteSampler(Sampler):
79
+ def __init__(
80
+ self,
81
+ *,
82
+ sample_count: int,
83
+ shuffle: bool = False,
84
+ seed: int = 0,
85
+ start: Optional[int] = None,
86
+ step: Optional[int] = None,
87
+ advance: int = 0,
88
+ ):
89
+ self._sample_count = sample_count
90
+ self._seed = seed
91
+ self._shuffle = shuffle
92
+ self._start = distributed.get_global_rank() if start is None else start
93
+ self._step = distributed.get_global_size() if step is None else step
94
+ self._advance = advance
95
+
96
+ def __iter__(self):
97
+ if self._shuffle:
98
+ iterator = self._shuffled_iterator()
99
+ else:
100
+ iterator = self._iterator()
101
+
102
+ yield from itertools.islice(iterator, self._advance, None)
103
+
104
+ def _iterator(self):
105
+ assert not self._shuffle
106
+
107
+ while True:
108
+ iterable = range(self._sample_count)
109
+ yield from itertools.islice(iterable, self._start, None, self._step)
110
+
111
+ def _shuffled_iterator(self):
112
+ assert self._shuffle
113
+
114
+ # Instantiate a generator here (rather than in the ctor) to keep the class
115
+ # picklable (requirement of mp.spawn)
116
+ generator = torch.Generator().manual_seed(self._seed)
117
+
118
+ while True:
119
+ iterable = _generate_randperm_indices(size=self._sample_count, generator=generator)
120
+ yield from itertools.islice(iterable, self._start, None, self._step)
121
+
122
+
123
+ # The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
124
+ # but avoids a full in-place random permutation generation.
125
+ def _shuffle_tensor_slice(
126
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
127
+ ) -> np.ndarray:
128
+ stop = len(tensor)
129
+ count = stop // step
130
+ drop_count = stop - step * count
131
+ if drop_count:
132
+ warnings.warn(f"# of dropped samples: {drop_count}")
133
+
134
+ dtype = _get_numpy_dtype(stop)
135
+ result = np.empty(count, dtype=dtype)
136
+
137
+ for i in range(count):
138
+ j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0
139
+
140
+ result[i] = result[j]
141
+ result[j] = tensor[start + i * step].item()
142
+
143
+ return result
144
+
145
+
146
+ def _new_shuffle_tensor_slice(
147
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
148
+ ) -> np.ndarray:
149
+ stop = len(tensor)
150
+ count = stop // step
151
+ dtype = torch.int64 # Needed for using randperm result as indices
152
+ count = stop // step
153
+ drop_count = stop - step * count
154
+ if drop_count:
155
+ warnings.warn(f"# of dropped samples: {drop_count}")
156
+ indices = torch.randperm(count, dtype=dtype, generator=generator)
157
+ return tensor[start::step][indices].numpy()
158
+
159
+
160
+ def _make_seed(seed: int, start: int, iter_count: int) -> int:
161
+ # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
162
+ return seed + start + (iter_count << 24)
163
+
164
+
165
+ class ShardedInfiniteSampler(Sampler):
166
+ def __init__(
167
+ self,
168
+ *,
169
+ sample_count: int,
170
+ shuffle: bool = False,
171
+ seed: int = 0,
172
+ start: Optional[int] = None,
173
+ step: Optional[int] = None,
174
+ advance: int = 0,
175
+ use_new_shuffle_tensor_slice: bool = False,
176
+ ):
177
+ self._sample_count = sample_count
178
+ self._seed = seed
179
+ self._shuffle = shuffle
180
+ self._start = distributed.get_global_rank() if start is None else start
181
+ self._step = distributed.get_global_size() if step is None else step
182
+ self._advance = advance
183
+ self._iter_count = 0
184
+ self._shuffle_tensor_slice_fn = (
185
+ _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice
186
+ )
187
+
188
+ def __iter__(self):
189
+ iter_count = self._advance // self._sample_count
190
+ if iter_count > 0:
191
+ self._advance -= iter_count * self._sample_count
192
+ self._iter_count += iter_count
193
+
194
+ if self._shuffle:
195
+ iterator = self._shuffled_iterator()
196
+ else:
197
+ iterator = self._iterator()
198
+
199
+ yield from itertools.islice(iterator, self._advance, None)
200
+
201
+ def _iterator(self):
202
+ assert not self._shuffle
203
+
204
+ while True:
205
+ iterable = range(self._sample_count)
206
+ yield from itertools.islice(iterable, self._start, None, self._step)
207
+
208
+ def _shuffled_iterator(self):
209
+ assert self._shuffle
210
+
211
+ # Instantiate a generator here (rather than in the ctor) to keep the class
212
+ # picklable (requirement of mp.spawn)
213
+ generator = torch.Generator()
214
+
215
+ # Always shuffle everything first
216
+ generator.manual_seed(self._seed)
217
+ dtype = _get_torch_dtype(self._sample_count)
218
+ perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
219
+
220
+ while True:
221
+ # Re-seed on each iteration to allow skipping whole permutations
222
+ seed = _make_seed(self._seed, self._start, self._iter_count)
223
+ generator.manual_seed(seed)
224
+
225
+ iterable = self._shuffle_tensor_slice_fn(
226
+ tensor=perm, start=self._start, step=self._step, generator=generator
227
+ )
228
+ yield from iterable
229
+ self._iter_count += 1
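
A hedged sketch of how a sharded infinite sampler like the one above can drive a DataLoader; the dataset is a stand-in, and single-process values are assumed for start/step (in a distributed run they default to the global rank and world size).

    import itertools

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from dinov2.data.samplers import ShardedInfiniteSampler   # module added in this commit

    dataset = TensorDataset(torch.arange(1000).float())

    sampler = ShardedInfiniteSampler(
        sample_count=len(dataset),
        shuffle=True,
        seed=0,
        start=0,     # rank of this shard (assumed single process here)
        step=1,      # total number of shards (assumed single process here)
        advance=0,   # samples to skip up front, e.g. when resuming training
    )

    loader = DataLoader(dataset, sampler=sampler, batch_size=32)

    # The sampler never stops, so consume a fixed number of batches explicitly.
    for (batch,) in itertools.islice(loader, 5):
        print(batch.shape)   # torch.Size([32])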
dinov2/data/transforms.py ADDED
@@ -0,0 +1,91 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from typing import Sequence
7
+
8
+ import torch
9
+ from torchvision import transforms
10
+
11
+
12
+ class GaussianBlur(transforms.RandomApply):
13
+ """
14
+ Apply Gaussian Blur to the PIL image.
15
+ """
16
+
17
+ def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0):
18
+ # NOTE: torchvision is applying 1 - probability to return the original image
19
+ keep_p = 1 - p
20
+ transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max))
21
+ super().__init__(transforms=[transform], p=keep_p)
22
+
23
+
24
+ class MaybeToTensor(transforms.ToTensor):
25
+ """
26
+ Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor.
27
+ """
28
+
29
+ def __call__(self, pic):
30
+ """
31
+ Args:
32
+ pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor.
33
+ Returns:
34
+ Tensor: Converted image.
35
+ """
36
+ if isinstance(pic, torch.Tensor):
37
+ return pic
38
+ return super().__call__(pic)
39
+
40
+
41
+ # Use timm's names
42
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
43
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
44
+
45
+
46
+ def make_normalize_transform(
47
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
48
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
49
+ ) -> transforms.Normalize:
50
+ return transforms.Normalize(mean=mean, std=std)
51
+
52
+
53
+ # This roughly matches torchvision's preset for classification training:
54
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44
55
+ def make_classification_train_transform(
56
+ *,
57
+ crop_size: int = 224,
58
+ interpolation=transforms.InterpolationMode.BICUBIC,
59
+ hflip_prob: float = 0.5,
60
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
61
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
62
+ ):
63
+ transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
64
+ if hflip_prob > 0.0:
65
+ transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob))
66
+ transforms_list.extend(
67
+ [
68
+ MaybeToTensor(),
69
+ make_normalize_transform(mean=mean, std=std),
70
+ ]
71
+ )
72
+ return transforms.Compose(transforms_list)
73
+
74
+
75
+ # This matches (roughly) torchvision's preset for classification evaluation:
76
+ # https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69
77
+ def make_classification_eval_transform(
78
+ *,
79
+ resize_size: int = 256,
80
+ interpolation=transforms.InterpolationMode.BICUBIC,
81
+ crop_size: int = 224,
82
+ mean: Sequence[float] = IMAGENET_DEFAULT_MEAN,
83
+ std: Sequence[float] = IMAGENET_DEFAULT_STD,
84
+ ) -> transforms.Compose:
85
+ transforms_list = [
86
+ transforms.Resize(resize_size, interpolation=interpolation),
87
+ transforms.CenterCrop(crop_size),
88
+ MaybeToTensor(),
89
+ make_normalize_transform(mean=mean, std=std),
90
+ ]
91
+ return transforms.Compose(transforms_list)
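
For illustration, the evaluation preset above applied to a synthetic PIL image (the input size is arbitrary):

    import numpy as np
    from PIL import Image

    from dinov2.data.transforms import make_classification_eval_transform   # module added in this commit

    transform = make_classification_eval_transform()   # Resize(256) -> CenterCrop(224) -> ToTensor -> Normalize

    image = Image.fromarray(np.random.randint(0, 256, size=(300, 400, 3), dtype=np.uint8))
    tensor = transform(image)
    print(tensor.shape)   # torch.Size([3, 224, 224]), normalized with the ImageNet mean/std above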
dinov2/distributed/__init__.py ADDED
@@ -0,0 +1,270 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import re
9
+ import socket
10
+ from typing import Dict, List
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+
15
+ _LOCAL_RANK = -1
16
+ _LOCAL_WORLD_SIZE = -1
17
+
18
+
19
+ def is_enabled() -> bool:
20
+ """
21
+ Returns:
22
+ True if distributed training is enabled
23
+ """
24
+ return dist.is_available() and dist.is_initialized()
25
+
26
+
27
+ def get_global_size() -> int:
28
+ """
29
+ Returns:
30
+ The number of processes in the process group
31
+ """
32
+ return dist.get_world_size() if is_enabled() else 1
33
+
34
+
35
+ def get_global_rank() -> int:
36
+ """
37
+ Returns:
38
+ The rank of the current process within the global process group.
39
+ """
40
+ return dist.get_rank() if is_enabled() else 0
41
+
42
+
43
+ def get_local_rank() -> int:
44
+ """
45
+ Returns:
46
+ The rank of the current process within the local (per-machine) process group.
47
+ """
48
+ if not is_enabled():
49
+ return 0
50
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
51
+ return _LOCAL_RANK
52
+
53
+
54
+ def get_local_size() -> int:
55
+ """
56
+ Returns:
57
+ The size of the per-machine process group,
58
+ i.e. the number of processes per machine.
59
+ """
60
+ if not is_enabled():
61
+ return 1
62
+ assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE
63
+ return _LOCAL_WORLD_SIZE
64
+
65
+
66
+ def is_main_process() -> bool:
67
+ """
68
+ Returns:
69
+ True if the current process is the main one.
70
+ """
71
+ return get_global_rank() == 0
72
+
73
+
74
+ def _restrict_print_to_main_process() -> None:
75
+ """
76
+ This function disables printing when not in the main process
77
+ """
78
+ import builtins as __builtin__
79
+
80
+ builtin_print = __builtin__.print
81
+
82
+ def print(*args, **kwargs):
83
+ force = kwargs.pop("force", False)
84
+ if is_main_process() or force:
85
+ builtin_print(*args, **kwargs)
86
+
87
+ __builtin__.print = print
88
+
89
+
90
+ def _get_master_port(seed: int = 0) -> int:
91
+ MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000)
92
+
93
+ master_port_str = os.environ.get("MASTER_PORT")
94
+ if master_port_str is None:
95
+ rng = random.Random(seed)
96
+ return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT)
97
+
98
+ return int(master_port_str)
99
+
100
+
101
+ def _get_available_port() -> int:
102
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
103
+ # A "" host address means INADDR_ANY i.e. binding to all interfaces.
104
+ # Note this is not compatible with IPv6.
105
+ s.bind(("", 0))
106
+ port = s.getsockname()[1]
107
+ return port
108
+
109
+
110
+ _TORCH_DISTRIBUTED_ENV_VARS = (
111
+ "MASTER_ADDR",
112
+ "MASTER_PORT",
113
+ "RANK",
114
+ "WORLD_SIZE",
115
+ "LOCAL_RANK",
116
+ "LOCAL_WORLD_SIZE",
117
+ )
118
+
119
+
120
+ def _collect_env_vars() -> Dict[str, str]:
121
+ return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ}
122
+
123
+
124
+ def _is_slurm_job_process() -> bool:
125
+ return "SLURM_JOB_ID" in os.environ
126
+
127
+
128
+ def _parse_slurm_node_list(s: str) -> List[str]:
129
+ nodes = []
130
+ # Extract "hostname", "hostname[1-2,3,4-5]," substrings
131
+ p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?")
132
+ for m in p.finditer(s):
133
+ prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)]
134
+ for suffix in suffixes.split(","):
135
+ span = suffix.split("-")
136
+ if len(span) == 1:
137
+ nodes.append(prefix + suffix)
138
+ else:
139
+ width = len(span[0])
140
+ start, end = int(span[0]), int(span[1]) + 1
141
+ nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)])
142
+ return nodes
143
+
144
+
145
+ def _check_env_variable(key: str, new_value: str):
146
+ # Only check for difference with preset environment variables
147
+ if key in os.environ and os.environ[key] != new_value:
148
+ raise RuntimeError(f"Cannot export environment variables as {key} is already set")
149
+
150
+
151
+ class _TorchDistributedEnvironment:
152
+ def __init__(self):
153
+ self.master_addr = "127.0.0.1"
154
+ self.master_port = 0
155
+ self.rank = -1
156
+ self.world_size = -1
157
+ self.local_rank = -1
158
+ self.local_world_size = -1
159
+
160
+ if _is_slurm_job_process():
161
+ return self._set_from_slurm_env()
162
+
163
+ env_vars = _collect_env_vars()
164
+ if not env_vars:
165
+ # Environment is not set
166
+ pass
167
+ elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS):
168
+ # Environment is fully set
169
+ return self._set_from_preset_env()
170
+ else:
171
+ # Environment is partially set
172
+ collected_env_vars = ", ".join(env_vars.keys())
173
+ raise RuntimeError(f"Partially set environment: {collected_env_vars}")
174
+
175
+ if torch.cuda.device_count() > 0:
176
+ return self._set_from_local()
177
+
178
+ raise RuntimeError("Can't initialize PyTorch distributed environment")
179
+
180
+ # Slurm job created with sbatch, submitit, etc...
181
+ def _set_from_slurm_env(self):
182
+ # logger.info("Initialization from Slurm environment")
183
+ job_id = int(os.environ["SLURM_JOB_ID"])
184
+ node_count = int(os.environ["SLURM_JOB_NUM_NODES"])
185
+ nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])
186
+ assert len(nodes) == node_count
187
+
188
+ self.master_addr = nodes[0]
189
+ self.master_port = _get_master_port(seed=job_id)
190
+ self.rank = int(os.environ["SLURM_PROCID"])
191
+ self.world_size = int(os.environ["SLURM_NTASKS"])
192
+ assert self.rank < self.world_size
193
+ self.local_rank = int(os.environ["SLURM_LOCALID"])
194
+ self.local_world_size = self.world_size // node_count
195
+ assert self.local_rank < self.local_world_size
196
+
197
+ # Single node job with preset environment (i.e. torchrun)
198
+ def _set_from_preset_env(self):
199
+ # logger.info("Initialization from preset environment")
200
+ self.master_addr = os.environ["MASTER_ADDR"]
201
+ self.master_port = os.environ["MASTER_PORT"]
202
+ self.rank = int(os.environ["RANK"])
203
+ self.world_size = int(os.environ["WORLD_SIZE"])
204
+ assert self.rank < self.world_size
205
+ self.local_rank = int(os.environ["LOCAL_RANK"])
206
+ self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
207
+ assert self.local_rank < self.local_world_size
208
+
209
+ # Single node and GPU job (i.e. local script run)
210
+ def _set_from_local(self):
211
+ # logger.info("Initialization from local")
212
+ self.master_addr = "127.0.0.1"
213
+ self.master_port = _get_available_port()
214
+ self.rank = 0
215
+ self.world_size = 1
216
+ self.local_rank = 0
217
+ self.local_world_size = 1
218
+
219
+ def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment":
220
+ # See the "Environment variable initialization" section from
221
+ # https://pytorch.org/docs/stable/distributed.html for the complete list of
222
+ # environment variables required for the env:// initialization method.
223
+ env_vars = {
224
+ "MASTER_ADDR": self.master_addr,
225
+ "MASTER_PORT": str(self.master_port),
226
+ "RANK": str(self.rank),
227
+ "WORLD_SIZE": str(self.world_size),
228
+ "LOCAL_RANK": str(self.local_rank),
229
+ "LOCAL_WORLD_SIZE": str(self.local_world_size),
230
+ }
231
+ if not overwrite:
232
+ for k, v in env_vars.items():
233
+ _check_env_variable(k, v)
234
+
235
+ os.environ.update(env_vars)
236
+ return self
237
+
238
+
239
+ def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False):
240
+ """Enable distributed mode
241
+
242
+ Args:
243
+ set_cuda_current_device: If True, call torch.cuda.set_device() to set the
244
+ current PyTorch CUDA device to the one matching the local rank.
245
+ overwrite: If True, overwrites already set variables. Else fails.
246
+ """
247
+
248
+ global _LOCAL_RANK, _LOCAL_WORLD_SIZE
249
+ if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0:
250
+ raise RuntimeError("Distributed mode has already been enabled")
251
+ torch_env = _TorchDistributedEnvironment()
252
+ torch_env.export(overwrite=overwrite)
253
+
254
+ if set_cuda_current_device:
255
+ torch.cuda.set_device(torch_env.local_rank)
256
+
257
+ if allow_nccl_timeout:
258
+ # This allows to use torch distributed timeout in a NCCL backend
259
+ key, value = "NCCL_ASYNC_ERROR_HANDLING", "1"
260
+ if not overwrite:
261
+ _check_env_variable(key, value)
262
+ os.environ[key] = value
263
+
264
+ dist.init_process_group(backend="nccl")
265
+ dist.barrier()
266
+
267
+ # Finalize setup
268
+ _LOCAL_RANK = torch_env.local_rank
269
+ _LOCAL_WORLD_SIZE = torch_env.local_world_size
270
+ _restrict_print_to_main_process()
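
A minimal sketch of using this helper, assuming at least one CUDA device and an NCCL-capable PyTorch build (enable() initializes an NCCL process group); such a script would typically be launched with torchrun so the MASTER_ADDR/RANK/... variables are preset.

    import torch
    import torch.distributed as dist

    import dinov2.distributed as distributed

    distributed.enable(overwrite=True)   # also sets the current CUDA device to the local rank

    if distributed.is_main_process():
        print(f"world size: {distributed.get_global_size()}")

    x = torch.ones(1, device="cuda") * distributed.get_global_rank()
    dist.all_reduce(x)   # x now holds the sum of all rank ids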
dinov2/eval/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
dinov2/eval/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (171 Bytes).
dinov2/eval/depth/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
dinov2/eval/depth/models/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .backbones import * # noqa: F403
7
+ from .builder import BACKBONES, DEPTHER, HEADS, LOSSES, build_backbone, build_depther, build_head, build_loss
8
+ from .decode_heads import * # noqa: F403
9
+ from .depther import * # noqa: F403
10
+ from .losses import * # noqa: F403
dinov2/eval/depth/models/backbones/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .vision_transformer import DinoVisionTransformer
dinov2/eval/depth/models/backbones/vision_transformer.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from mmcv.runner import BaseModule
7
+
8
+ from ..builder import BACKBONES
9
+
10
+
11
+ @BACKBONES.register_module()
12
+ class DinoVisionTransformer(BaseModule):
13
+ """Vision Transformer."""
14
+
15
+ def __init__(self, *args, **kwargs):
16
+ super().__init__()
dinov2/eval/depth/models/builder.py ADDED
@@ -0,0 +1,49 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import warnings
7
+
8
+ from mmcv.cnn import MODELS as MMCV_MODELS
9
+ from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
10
+ from mmcv.utils import Registry
11
+
12
+ MODELS = Registry("models", parent=MMCV_MODELS)
13
+ ATTENTION = Registry("attention", parent=MMCV_ATTENTION)
14
+
15
+
16
+ BACKBONES = MODELS
17
+ NECKS = MODELS
18
+ HEADS = MODELS
19
+ LOSSES = MODELS
20
+ DEPTHER = MODELS
21
+
22
+
23
+ def build_backbone(cfg):
24
+ """Build backbone."""
25
+ return BACKBONES.build(cfg)
26
+
27
+
28
+ def build_neck(cfg):
29
+ """Build neck."""
30
+ return NECKS.build(cfg)
31
+
32
+
33
+ def build_head(cfg):
34
+ """Build head."""
35
+ return HEADS.build(cfg)
36
+
37
+
38
+ def build_loss(cfg):
39
+ """Build loss."""
40
+ return LOSSES.build(cfg)
41
+
42
+
43
+ def build_depther(cfg, train_cfg=None, test_cfg=None):
44
+ """Build depther."""
45
+ if train_cfg is not None or test_cfg is not None:
46
+ warnings.warn("train_cfg and test_cfg is deprecated, " "please specify them in model", UserWarning)
47
+ assert cfg.get("train_cfg") is None or train_cfg is None, "train_cfg specified in both outer field and model field "
48
+ assert cfg.get("test_cfg") is None or test_cfg is None, "test_cfg specified in both outer field and model field "
49
+ return DEPTHER.build(cfg, default_args=dict(train_cfg=train_cfg, test_cfg=test_cfg))
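
For reference, a small sketch of the registry pattern used by this builder; it assumes mmcv (mmcv.cnn / mmcv.runner) is installed and that importing the models package registers the backbone class used below.

    from dinov2.eval.depth.models import build_backbone   # importing the package registers the modules

    # The registry dispatches on the "type" key and passes the remaining keys as constructor kwargs.
    backbone_cfg = dict(type="DinoVisionTransformer")
    backbone = build_backbone(backbone_cfg)
    print(type(backbone).__name__)   # DinoVisionTransformer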
dinov2/eval/depth/models/decode_heads/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dpt_head import DPTHead
7
+ from .linear_head import BNHead
dinov2/eval/depth/models/decode_heads/decode_head.py ADDED
@@ -0,0 +1,225 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import copy
7
+ from abc import ABCMeta, abstractmethod
8
+
9
+ import mmcv
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ from mmcv.runner import BaseModule, auto_fp16, force_fp32
14
+
15
+ from ...ops import resize
16
+ from ..builder import build_loss
17
+
18
+
19
+ class DepthBaseDecodeHead(BaseModule, metaclass=ABCMeta):
20
+ """Base class for BaseDecodeHead.
21
+
22
+ Args:
23
+ in_channels (List): Input channels.
24
+ channels (int): Channels after modules, before conv_depth.
25
+ conv_cfg (dict|None): Config of conv layers. Default: None.
26
+ act_cfg (dict): Config of activation layers.
27
+ Default: dict(type='ReLU')
28
+ loss_decode (dict): Config of decode loss.
29
+ Default: dict(type='SigLoss').
30
+ sampler (dict|None): The config of depth map sampler.
31
+ Default: None.
32
+ align_corners (bool): align_corners argument of F.interpolate.
33
+ Default: False.
34
+ min_depth (float): Minimum depth in the dataset setting.
35
+ Default: 1e-3.
36
+ max_depth (float): Maximum depth in the dataset setting.
37
+ Default: None.
38
+ norm_cfg (dict|None): Config of norm layers.
39
+ Default: None.
40
+ classify (bool): Whether to predict depth in a classification-regression (cls.-reg.) manner.
41
+ Default: False.
42
+ n_bins (int): The number of bins used in cls. step.
43
+ Default: 256.
44
+ bins_strategy (str): The discrete strategy used in cls. step.
45
+ Default: 'UD'.
46
+ norm_strategy (str): The norm strategy on cls. probability
47
+ distribution. Default: 'linear'
48
+ scale_up (bool): Whether to predict depth in a scale-up manner.
49
+ Default: False.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ in_channels,
55
+ channels=96,
56
+ conv_cfg=None,
57
+ act_cfg=dict(type="ReLU"),
58
+ loss_decode=dict(type="SigLoss", valid_mask=True, loss_weight=10),
59
+ sampler=None,
60
+ align_corners=False,
61
+ min_depth=1e-3,
62
+ max_depth=None,
63
+ norm_cfg=None,
64
+ classify=False,
65
+ n_bins=256,
66
+ bins_strategy="UD",
67
+ norm_strategy="linear",
68
+ scale_up=False,
69
+ ):
70
+ super(DepthBaseDecodeHead, self).__init__()
71
+
72
+ self.in_channels = in_channels
73
+ self.channels = channels
74
+ self.conv_cfg = conv_cfg
75
+ self.act_cfg = act_cfg
76
+ if isinstance(loss_decode, dict):
77
+ self.loss_decode = build_loss(loss_decode)
78
+ elif isinstance(loss_decode, (list, tuple)):
79
+ self.loss_decode = nn.ModuleList()
80
+ for loss in loss_decode:
81
+ self.loss_decode.append(build_loss(loss))
82
+ self.align_corners = align_corners
83
+ self.min_depth = min_depth
84
+ self.max_depth = max_depth
85
+ self.norm_cfg = norm_cfg
86
+ self.classify = classify
87
+ self.n_bins = n_bins
88
+ self.scale_up = scale_up
89
+
90
+ if self.classify:
91
+ assert bins_strategy in ["UD", "SID"], "Support bins_strategy: UD, SID"
92
+ assert norm_strategy in ["linear", "softmax", "sigmoid"], "Support norm_strategy: linear, softmax, sigmoid"
93
+
94
+ self.bins_strategy = bins_strategy
95
+ self.norm_strategy = norm_strategy
96
+ self.softmax = nn.Softmax(dim=1)
97
+ self.conv_depth = nn.Conv2d(channels, n_bins, kernel_size=3, padding=1, stride=1)
98
+ else:
99
+ self.conv_depth = nn.Conv2d(channels, 1, kernel_size=3, padding=1, stride=1)
100
+
101
+ self.fp16_enabled = False
102
+ self.relu = nn.ReLU()
103
+ self.sigmoid = nn.Sigmoid()
104
+
105
+ def extra_repr(self):
106
+ """Extra repr."""
107
+ s = f"align_corners={self.align_corners}"
108
+ return s
109
+
110
+ @auto_fp16()
111
+ @abstractmethod
112
+ def forward(self, inputs, img_metas):
113
+ """Placeholder of forward function."""
114
+ pass
115
+
116
+ def forward_train(self, img, inputs, img_metas, depth_gt, train_cfg):
117
+ """Forward function for training.
118
+ Args:
119
+ inputs (list[Tensor]): List of multi-level img features.
120
+ img_metas (list[dict]): List of image info dict where each dict
121
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
122
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
123
+ For details on the values of these keys see
124
+ `depth/datasets/pipelines/formatting.py:Collect`.
125
+ depth_gt (Tensor): GT depth
126
+ train_cfg (dict): The training config.
127
+
128
+ Returns:
129
+ dict[str, Tensor]: a dictionary of loss components
130
+ """
131
+ depth_pred = self.forward(inputs, img_metas)
132
+ losses = self.losses(depth_pred, depth_gt)
133
+
134
+ log_imgs = self.log_images(img[0], depth_pred[0], depth_gt[0], img_metas[0])
135
+ losses.update(**log_imgs)
136
+
137
+ return losses
138
+
139
+ def forward_test(self, inputs, img_metas, test_cfg):
140
+ """Forward function for testing.
141
+ Args:
142
+ inputs (list[Tensor]): List of multi-level img features.
143
+ img_metas (list[dict]): List of image info dict where each dict
144
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
145
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
146
+ For details on the values of these keys see
147
+ `depth/datasets/pipelines/formatting.py:Collect`.
148
+ test_cfg (dict): The testing config.
149
+
150
+ Returns:
151
+ Tensor: Output depth map.
152
+ """
153
+ return self.forward(inputs, img_metas)
154
+
155
+ def depth_pred(self, feat):
156
+ """Prediction each pixel."""
157
+ if self.classify:
158
+ logit = self.conv_depth(feat)
159
+
160
+ if self.bins_strategy == "UD":
161
+ bins = torch.linspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
162
+ elif self.bins_strategy == "SID":
163
+ bins = torch.logspace(self.min_depth, self.max_depth, self.n_bins, device=feat.device)
164
+
165
+ # following Adabins, default linear
166
+ if self.norm_strategy == "linear":
167
+ logit = torch.relu(logit)
168
+ eps = 0.1
169
+ logit = logit + eps
170
+ logit = logit / logit.sum(dim=1, keepdim=True)
171
+ elif self.norm_strategy == "softmax":
172
+ logit = torch.softmax(logit, dim=1)
173
+ elif self.norm_strategy == "sigmoid":
174
+ logit = torch.sigmoid(logit)
175
+ logit = logit / logit.sum(dim=1, keepdim=True)
176
+
177
+ output = torch.einsum("ikmn,k->imn", [logit, bins]).unsqueeze(dim=1)
178
+
179
+ else:
180
+ if self.scale_up:
181
+ output = self.sigmoid(self.conv_depth(feat)) * self.max_depth
182
+ else:
183
+ output = self.relu(self.conv_depth(feat)) + self.min_depth
184
+ return output
185
+
186
+ @force_fp32(apply_to=("depth_pred",))
187
+ def losses(self, depth_pred, depth_gt):
188
+ """Compute depth loss."""
189
+ loss = dict()
190
+ depth_pred = resize(
191
+ input=depth_pred, size=depth_gt.shape[2:], mode="bilinear", align_corners=self.align_corners, warning=False
192
+ )
193
+ if not isinstance(self.loss_decode, nn.ModuleList):
194
+ losses_decode = [self.loss_decode]
195
+ else:
196
+ losses_decode = self.loss_decode
197
+ for loss_decode in losses_decode:
198
+ if loss_decode.loss_name not in loss:
199
+ loss[loss_decode.loss_name] = loss_decode(depth_pred, depth_gt)
200
+ else:
201
+ loss[loss_decode.loss_name] += loss_decode(depth_pred, depth_gt)
202
+ return loss
203
+
204
+ def log_images(self, img_path, depth_pred, depth_gt, img_meta):
205
+ show_img = copy.deepcopy(img_path.detach().cpu().permute(1, 2, 0))
206
+ show_img = show_img.numpy().astype(np.float32)
207
+ show_img = mmcv.imdenormalize(
208
+ show_img,
209
+ img_meta["img_norm_cfg"]["mean"],
210
+ img_meta["img_norm_cfg"]["std"],
211
+ img_meta["img_norm_cfg"]["to_rgb"],
212
+ )
213
+ show_img = np.clip(show_img, 0, 255)
214
+ show_img = show_img.astype(np.uint8)
215
+ show_img = show_img[:, :, ::-1]
216
+ show_img = show_img.transpose(0, 2, 1)
217
+ show_img = show_img.transpose(1, 0, 2)
218
+
219
+ depth_pred = depth_pred / torch.max(depth_pred)
220
+ depth_gt = depth_gt / torch.max(depth_gt)
221
+
222
+ depth_pred_color = copy.deepcopy(depth_pred.detach().cpu())
223
+ depth_gt_color = copy.deepcopy(depth_gt.detach().cpu())
224
+
225
+ return {"img_rgb": show_img, "img_depth_pred": depth_pred_color, "img_depth_gt": depth_gt_color}
dinov2/eval/depth/models/decode_heads/dpt_head.py ADDED
@@ -0,0 +1,270 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from mmcv.cnn import ConvModule, Linear, build_activation_layer
11
+ from mmcv.runner import BaseModule
12
+
13
+ from ...ops import resize
14
+ from ..builder import HEADS
15
+ from .decode_head import DepthBaseDecodeHead
16
+
17
+
18
+ class Interpolate(nn.Module):
19
+ def __init__(self, scale_factor, mode, align_corners=False):
20
+ super(Interpolate, self).__init__()
21
+ self.interp = nn.functional.interpolate
22
+ self.scale_factor = scale_factor
23
+ self.mode = mode
24
+ self.align_corners = align_corners
25
+
26
+ def forward(self, x):
27
+ x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners)
28
+ return x
29
+
30
+
31
+ class HeadDepth(nn.Module):
32
+ def __init__(self, features):
33
+ super(HeadDepth, self).__init__()
34
+ self.head = nn.Sequential(
35
+ nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
36
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
37
+ nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
38
+ nn.ReLU(),
39
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
40
+ )
41
+
42
+ def forward(self, x):
43
+ x = self.head(x)
44
+ return x
45
+
46
+
47
+ class ReassembleBlocks(BaseModule):
48
+ """ViTPostProcessBlock, process cls_token in ViT backbone output and
49
+ rearrange the feature vector to feature map.
50
+ Args:
51
+ in_channels (int): ViT feature channels. Default: 768.
52
+ out_channels (List): output channels of each stage.
53
+ Default: [96, 192, 384, 768].
54
+ readout_type (str): Type of readout operation. Default: 'ignore'.
55
+ patch_size (int): The patch size. Default: 16.
56
+ init_cfg (dict, optional): Initialization config dict. Default: None.
57
+ """
58
+
59
+ def __init__(
60
+ self, in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore", patch_size=16, init_cfg=None
61
+ ):
62
+ super(ReassembleBlocks, self).__init__(init_cfg)
63
+
64
+ assert readout_type in ["ignore", "add", "project"]
65
+ self.readout_type = readout_type
66
+ self.patch_size = patch_size
67
+
68
+ self.projects = nn.ModuleList(
69
+ [
70
+ ConvModule(
71
+ in_channels=in_channels,
72
+ out_channels=out_channel,
73
+ kernel_size=1,
74
+ act_cfg=None,
75
+ )
76
+ for out_channel in out_channels
77
+ ]
78
+ )
79
+
80
+ self.resize_layers = nn.ModuleList(
81
+ [
82
+ nn.ConvTranspose2d(
83
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
84
+ ),
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
87
+ ),
88
+ nn.Identity(),
89
+ nn.Conv2d(
90
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
91
+ ),
92
+ ]
93
+ )
94
+ if self.readout_type == "project":
95
+ self.readout_projects = nn.ModuleList()
96
+ for _ in range(len(self.projects)):
97
+ self.readout_projects.append(
98
+ nn.Sequential(Linear(2 * in_channels, in_channels), build_activation_layer(dict(type="GELU")))
99
+ )
100
+
101
+ def forward(self, inputs):
102
+ assert isinstance(inputs, list)
103
+ out = []
104
+ for i, x in enumerate(inputs):
105
+ assert len(x) == 2
106
+ x, cls_token = x[0], x[1]
107
+ feature_shape = x.shape
108
+ if self.readout_type == "project":
109
+ x = x.flatten(2).permute((0, 2, 1))
110
+ readout = cls_token.unsqueeze(1).expand_as(x)
111
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
112
+ x = x.permute(0, 2, 1).reshape(feature_shape)
113
+ elif self.readout_type == "add":
114
+ x = x.flatten(2) + cls_token.unsqueeze(-1)
115
+ x = x.reshape(feature_shape)
116
+ else:
117
+ pass
118
+ x = self.projects[i](x)
119
+ x = self.resize_layers[i](x)
120
+ out.append(x)
121
+ return out
122
+
123
+
124
+ class PreActResidualConvUnit(BaseModule):
125
+ """ResidualConvUnit, pre-activate residual unit.
126
+ Args:
127
+ in_channels (int): number of channels in the input feature map.
128
+ act_cfg (dict): dictionary to construct and config activation layer.
129
+ norm_cfg (dict): dictionary to construct and config norm layer.
130
+ stride (int): stride of the first block. Default: 1
131
+ dilation (int): dilation rate for convs layers. Default: 1.
132
+ init_cfg (dict, optional): Initialization config dict. Default: None.
133
+ """
134
+
135
+ def __init__(self, in_channels, act_cfg, norm_cfg, stride=1, dilation=1, init_cfg=None):
136
+ super(PreActResidualConvUnit, self).__init__(init_cfg)
137
+
138
+ self.conv1 = ConvModule(
139
+ in_channels,
140
+ in_channels,
141
+ 3,
142
+ stride=stride,
143
+ padding=dilation,
144
+ dilation=dilation,
145
+ norm_cfg=norm_cfg,
146
+ act_cfg=act_cfg,
147
+ bias=False,
148
+ order=("act", "conv", "norm"),
149
+ )
150
+
151
+ self.conv2 = ConvModule(
152
+ in_channels,
153
+ in_channels,
154
+ 3,
155
+ padding=1,
156
+ norm_cfg=norm_cfg,
157
+ act_cfg=act_cfg,
158
+ bias=False,
159
+ order=("act", "conv", "norm"),
160
+ )
161
+
162
+ def forward(self, inputs):
163
+ inputs_ = inputs.clone()
164
+ x = self.conv1(inputs)
165
+ x = self.conv2(x)
166
+ return x + inputs_
167
+
168
+
169
+ class FeatureFusionBlock(BaseModule):
170
+ """FeatureFusionBlock, merge feature map from different stages.
171
+ Args:
172
+ in_channels (int): Input channels.
173
+ act_cfg (dict): The activation config for ResidualConvUnit.
174
+ norm_cfg (dict): Config dict for normalization layer.
175
+ expand (bool): Whether to expand the channels in the post-process block.
176
+ Default: False.
177
+ align_corners (bool): align_corners setting for bilinear upsampling.
178
+ Default: True.
179
+ init_cfg (dict, optional): Initialization config dict. Default: None.
180
+ """
181
+
182
+ def __init__(self, in_channels, act_cfg, norm_cfg, expand=False, align_corners=True, init_cfg=None):
183
+ super(FeatureFusionBlock, self).__init__(init_cfg)
184
+
185
+ self.in_channels = in_channels
186
+ self.expand = expand
187
+ self.align_corners = align_corners
188
+
189
+ self.out_channels = in_channels
190
+ if self.expand:
191
+ self.out_channels = in_channels // 2
192
+
193
+ self.project = ConvModule(self.in_channels, self.out_channels, kernel_size=1, act_cfg=None, bias=True)
194
+
195
+ self.res_conv_unit1 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
196
+ self.res_conv_unit2 = PreActResidualConvUnit(in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
197
+
198
+ def forward(self, *inputs):
199
+ x = inputs[0]
200
+ if len(inputs) == 2:
201
+ if x.shape != inputs[1].shape:
202
+ res = resize(inputs[1], size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)
203
+ else:
204
+ res = inputs[1]
205
+ x = x + self.res_conv_unit1(res)
206
+ x = self.res_conv_unit2(x)
207
+ x = resize(x, scale_factor=2, mode="bilinear", align_corners=self.align_corners)
208
+ x = self.project(x)
209
+ return x
210
+
211
+
212
+ @HEADS.register_module()
213
+ class DPTHead(DepthBaseDecodeHead):
214
+ """Vision Transformers for Dense Prediction.
215
+ This head is an implementation of `DPT <https://arxiv.org/abs/2103.13413>`_.
216
+ Args:
217
+ embed_dims (int): The embed dimension of the ViT backbone.
218
+ Default: 768.
219
+ post_process_channels (List): Out channels of post process conv
220
+ layers. Default: [96, 192, 384, 768].
221
+ readout_type (str): Type of readout operation. Default: 'ignore'.
222
+ patch_size (int): The patch size. Default: 16.
223
+ expand_channels (bool): Whether to expand the channels in the post process
224
+ block. Default: False.
225
+ """
226
+
227
+ def __init__(
228
+ self,
229
+ embed_dims=768,
230
+ post_process_channels=[96, 192, 384, 768],
231
+ readout_type="ignore",
232
+ patch_size=16,
233
+ expand_channels=False,
234
+ **kwargs
235
+ ):
236
+ super(DPTHead, self).__init__(**kwargs)
237
+
238
+ self.in_channels = self.in_channels
239
+ self.expand_channels = expand_channels
240
+ self.reassemble_blocks = ReassembleBlocks(embed_dims, post_process_channels, readout_type, patch_size)
241
+
242
+ self.post_process_channels = [
243
+ channel * math.pow(2, i) if expand_channels else channel for i, channel in enumerate(post_process_channels)
244
+ ]
245
+ self.convs = nn.ModuleList()
246
+ for channel in self.post_process_channels:
247
+ self.convs.append(ConvModule(channel, self.channels, kernel_size=3, padding=1, act_cfg=None, bias=False))
248
+ self.fusion_blocks = nn.ModuleList()
249
+ for _ in range(len(self.convs)):
250
+ self.fusion_blocks.append(FeatureFusionBlock(self.channels, self.act_cfg, self.norm_cfg))
251
+ self.fusion_blocks[0].res_conv_unit1 = None
252
+ self.project = ConvModule(self.channels, self.channels, kernel_size=3, padding=1, norm_cfg=self.norm_cfg)
253
+ self.num_fusion_blocks = len(self.fusion_blocks)
254
+ self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
255
+ self.num_post_process_channels = len(self.post_process_channels)
256
+ assert self.num_fusion_blocks == self.num_reassemble_blocks
257
+ assert self.num_reassemble_blocks == self.num_post_process_channels
258
+ self.conv_depth = HeadDepth(self.channels)
259
+
260
+ def forward(self, inputs, img_metas):
261
+ assert len(inputs) == self.num_reassemble_blocks
262
+ x = [inp for inp in inputs]
263
+ x = self.reassemble_blocks(x)
264
+ x = [self.convs[i](feature) for i, feature in enumerate(x)]
265
+ out = self.fusion_blocks[0](x[-1])
266
+ for i in range(1, len(self.fusion_blocks)):
267
+ out = self.fusion_blocks[i](out, x[-(i + 1)])
268
+ out = self.project(out)
269
+ out = self.depth_pred(out)
270
+ return out
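
A shape-only sketch of ReassembleBlocks with its default settings (assumes mmcv is installed; the 16 x 16 token grid is an arbitrary choice): each stage receives a (patch feature map, cls_token) pair and is projected and resized to a different scale.

    import torch

    from dinov2.eval.depth.models.decode_heads.dpt_head import ReassembleBlocks

    blocks = ReassembleBlocks()   # in_channels=768, out_channels=[96, 192, 384, 768], readout_type="ignore"

    # One (patch feature map, cls_token) pair per selected transformer stage.
    features = [(torch.randn(1, 768, 16, 16), torch.randn(1, 768)) for _ in range(4)]
    outputs = blocks(features)
    print([tuple(out.shape) for out in outputs])
    # [(1, 96, 64, 64), (1, 192, 32, 32), (1, 384, 16, 16), (1, 768, 8, 8)]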
dinov2/eval/depth/models/decode_heads/linear_head.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from ...ops import resize
10
+ from ..builder import HEADS
11
+ from .decode_head import DepthBaseDecodeHead
12
+
13
+
14
+ @HEADS.register_module()
15
+ class BNHead(DepthBaseDecodeHead):
16
+ """Just a batchnorm."""
17
+
18
+ def __init__(self, input_transform="resize_concat", in_index=(0, 1, 2, 3), upsample=1, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.input_transform = input_transform
21
+ self.in_index = in_index
22
+ self.upsample = upsample
23
+ # self.bn = nn.SyncBatchNorm(self.in_channels)
24
+ if self.classify:
25
+ self.conv_depth = nn.Conv2d(self.channels, self.n_bins, kernel_size=1, padding=0, stride=1)
26
+ else:
27
+ self.conv_depth = nn.Conv2d(self.channels, 1, kernel_size=1, padding=0, stride=1)
28
+
29
+ def _transform_inputs(self, inputs):
30
+ """Transform inputs for decoder.
31
+ Args:
32
+ inputs (list[Tensor]): List of multi-level img features.
33
+ Returns:
34
+ Tensor: The transformed inputs
35
+ """
36
+
37
+ if "concat" in self.input_transform:
38
+ inputs = [inputs[i] for i in self.in_index]
39
+ if "resize" in self.input_transform:
40
+ inputs = [
41
+ resize(
42
+ input=x,
43
+ size=[s * self.upsample for s in inputs[0].shape[2:]],
44
+ mode="bilinear",
45
+ align_corners=self.align_corners,
46
+ )
47
+ for x in inputs
48
+ ]
49
+ inputs = torch.cat(inputs, dim=1)
50
+ elif self.input_transform == "multiple_select":
51
+ inputs = [inputs[i] for i in self.in_index]
52
+ else:
53
+ inputs = inputs[self.in_index]
54
+
55
+ return inputs
56
+
57
+ def _forward_feature(self, inputs, img_metas=None, **kwargs):
58
+ """Forward function for feature maps before classifying each pixel with
59
+ ``self.cls_seg`` fc.
60
+ Args:
61
+ inputs (list[Tensor]): List of multi-level img features.
62
+ Returns:
63
+ feats (Tensor): A tensor of shape (batch_size, self.channels,
64
+ H, W) which is feature map for last layer of decoder head.
65
+ """
66
+ # accept lists (for cls token)
67
+ inputs = list(inputs)
68
+ for i, x in enumerate(inputs):
69
+ if len(x) == 2:
70
+ x, cls_token = x[0], x[1]
71
+ if len(x.shape) == 2:
72
+ x = x[:, :, None, None]
73
+ cls_token = cls_token[:, :, None, None].expand_as(x)
74
+ inputs[i] = torch.cat((x, cls_token), 1)
75
+ else:
76
+ x = x[0]
77
+ if len(x.shape) == 2:
78
+ x = x[:, :, None, None]
79
+ inputs[i] = x
80
+ x = self._transform_inputs(inputs)
81
+ # feats = self.bn(x)
82
+ return x
83
+
84
+ def forward(self, inputs, img_metas=None, **kwargs):
85
+ """Forward function."""
86
+ output = self._forward_feature(inputs, img_metas=img_metas, **kwargs)
87
+ output = self.depth_pred(output)
88
+
89
+ return output
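
The "resize_concat" input transform used by BNHead can be approximated with plain torch; this standalone sketch mirrors _transform_inputs above with illustrative shapes (it is not the class itself and does not require mmcv).

    import torch
    import torch.nn.functional as F

    # Four multi-scale feature maps, e.g. taken from four transformer blocks (shapes are assumptions).
    inputs = [torch.randn(1, 384, 16, 16) for _ in range(4)]
    upsample, align_corners = 4, False

    # "resize_concat": upsample every level to a common size, then concatenate along the channel axis.
    size = [s * upsample for s in inputs[0].shape[2:]]
    resized = [F.interpolate(x, size=size, mode="bilinear", align_corners=align_corners) for x in inputs]
    fused = torch.cat(resized, dim=1)
    print(fused.shape)   # torch.Size([1, 1536, 64, 64])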