JianyuanWang committed
Commit febf487 · 1 Parent(s): 68f369a
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ examples/** filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,174 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # UV
98
+ # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ #uv.lock
102
+
103
+ # poetry
104
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
106
+ # commonly ignored for libraries.
107
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108
+ #poetry.lock
109
+
110
+ # pdm
111
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112
+ #pdm.lock
113
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114
+ # in version control.
115
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116
+ .pdm.toml
117
+ .pdm-python
118
+ .pdm-build/
119
+
120
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121
+ __pypackages__/
122
+
123
+ # Celery stuff
124
+ celerybeat-schedule
125
+ celerybeat.pid
126
+
127
+ # SageMath parsed files
128
+ *.sage.py
129
+
130
+ # Environments
131
+ .env
132
+ .venv
133
+ env/
134
+ venv/
135
+ ENV/
136
+ env.bak/
137
+ venv.bak/
138
+
139
+ # Spyder project settings
140
+ .spyderproject
141
+ .spyproject
142
+
143
+ # Rope project settings
144
+ .ropeproject
145
+
146
+ # mkdocs documentation
147
+ /site
148
+
149
+ # mypy
150
+ .mypy_cache/
151
+ .dmypy.json
152
+ dmypy.json
153
+
154
+ # Pyre type checker
155
+ .pyre/
156
+
157
+ # pytype static type analyzer
158
+ .pytype/
159
+
160
+ # Cython debug symbols
161
+ cython_debug/
162
+
163
+ # PyCharm
164
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
167
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168
+ #.idea/
169
+
170
+ # Ruff stuff:
171
+ .ruff_cache/
172
+
173
+ # PyPI configuration file
174
+ .pypirc
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Vggt
+ title: vggt
3
- emoji: 👁
+ emoji: 🏆
4
- colorFrom: green
+ colorFrom: indigo
5
- colorTo: purple
+ colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.18.0
+ sdk_version: 5.17.1
8
  app_file: app.py
9
  pinned: false
10
  license: cc-by-nc-4.0
app.py ADDED
@@ -0,0 +1,272 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import spaces
7
+ import sys
8
+ import os
9
+ import socket
10
+ import webbrowser
11
+ sys.path.append('vggt/')
12
+ import shutil
13
+ from datetime import datetime
14
+ from demo_hf import demo_fn
15
+ from omegaconf import DictConfig, OmegaConf
16
+ import glob
17
+ import gc
18
+ import time
19
+ from viser_fn import viser_wrapper
20
+
21
+
22
+ def get_free_port():
23
+ """Get a free port using socket."""
24
+ # return 80
25
+ # return 8080
26
+ # return 10088 # for debugging
27
+ # return 7860
28
+ # return 7888
29
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
30
+ s.bind(('', 0))
31
+ port = s.getsockname()[1]
32
+ return port
33
+
34
+
35
+
36
+
37
+ @spaces.GPU(duration=240)
38
+ def vggt_demo(
39
+ input_video,
40
+ input_image,
41
+ ):
42
+ start_time = time.time()
43
+ gc.collect()
44
+ torch.cuda.empty_cache()
45
+
46
+
47
+ debug = False
48
+
49
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
50
+ target_dir = f"input_images_{timestamp}"
51
+ if os.path.exists(target_dir):
52
+ shutil.rmtree(target_dir)
53
+
54
+ os.makedirs(target_dir)
55
+ target_dir_images = target_dir + "/images"
56
+ os.makedirs(target_dir_images)
57
+
58
+
59
+ if input_video is not None:
60
+ if not isinstance(input_video, str):
61
+ input_video = input_video["video"]["path"]
62
+
63
+ cfg_file = "config/base.yaml"
64
+ cfg = OmegaConf.load(cfg_file)
65
+
66
+ if input_image is not None:
67
+ input_image = sorted(input_image)
68
+ # recon_num = len(input_image)
69
+
70
+ # Copy files to the new directory
71
+ for file_name in input_image:
72
+ shutil.copy(file_name, target_dir_images)
73
+ elif input_video is not None:
74
+ vs = cv2.VideoCapture(input_video)
75
+
76
+ fps = vs.get(cv2.CAP_PROP_FPS)
77
+
78
+ frame_rate = 1
79
+ frame_interval = int(fps * frame_rate)
80
+
81
+ video_frame_num = 0
82
+ count = 0
83
+
84
+ while True:
85
+ (gotit, frame) = vs.read()
86
+ count +=1
87
+
88
+ if not gotit:
89
+ break
90
+
91
+ if count % frame_interval == 0:
92
+ cv2.imwrite(target_dir_images+"/"+f"{video_frame_num:06}.png", frame)
93
+ video_frame_num+=1
94
+
95
+ # recon_num = video_frame_num
96
+ # if recon_num<3:
97
+ # return None, "Please input at least three frames"
98
+ else:
99
+ return None, "Uploading not finished or Incorrect input format"
100
+
101
+
102
+ print(f"Files have been copied to {target_dir_images}")
103
+ cfg.SCENE_DIR = target_dir
104
+
105
+ predictions = demo_fn(cfg)
106
+
107
+ # Get a free port for viser
108
+ viser_port = get_free_port()
109
+
110
+ # Start viser visualization in a separate thread/process
111
+ viser_wrapper(predictions, port=viser_port)
112
+
113
+ del predictions
114
+ gc.collect()
115
+ torch.cuda.empty_cache()
116
+
117
+ print(input_image)
118
+ print(input_video)
119
+ end_time = time.time()
120
+ execution_time = end_time - start_time
121
+ print(f"Execution time: {execution_time} seconds")
122
+
123
+ # Return None for the 3D model (since we're using viser) and the viser URL
124
+ # viser_url = f"Viser visualization is ready at: http://localhost:{viser_port}"
125
+ # print(viser_url) # Debug print
126
+ return None, viser_port
127
+
128
+
129
+
130
+
131
+ statue_video = "examples/videos/statue_video.mp4"
132
+
133
+ apple_video = "examples/videos/apple_video.mp4"
134
+ british_museum_video = "examples/videos/british_museum_video.mp4"
135
+ cake_video = "examples/videos/cake_video.mp4"
136
+ bonsai_video = "examples/videos/bonsai_video.mp4"
137
+ face_video = "examples/videos/in2n_face_video.mp4"
138
+ counter_video = "examples/videos/in2n_counter_video.mp4"
139
+
140
+ horns_video = "examples/videos/llff_horns_video.mp4"
141
+ person_video = "examples/videos/in2n_person_video.mp4"
142
+
143
+ flower_video = "examples/videos/llff_flower_video.mp4"
144
+
145
+ fern_video = "examples/videos/llff_fern_video.mp4"
146
+
147
+ drums_video = "examples/videos/drums_video.mp4"
148
+
149
+ kitchen_video = "examples/videos/kitchen_video.mp4"
150
+
151
+ ###########################################################################################
152
+ apple_images = glob.glob(f'examples/apple/images/*')
153
+ bonsai_images = glob.glob(f'examples/bonsai/images/*')
154
+ cake_images = glob.glob(f'examples/cake/images/*')
155
+ british_museum_images = glob.glob(f'examples/british_museum/images/*')
156
+ face_images = glob.glob(f'examples/in2n_face/images/*')
157
+ counter_images = glob.glob(f'examples/in2n_counter/images/*')
158
+
159
+ horns_images = glob.glob(f'examples/llff_horns/images/*')
160
+
161
+ person_images = glob.glob(f'examples/in2n_person/images/*')
162
+ flower_images = glob.glob(f'examples/llff_flower/images/*')
163
+
164
+ fern_images = glob.glob(f'examples/llff_fern/images/*')
165
+ statue_images = glob.glob(f'examples/statue/images/*')
166
+
167
+ drums_images = glob.glob(f'examples/drums/images/*')
168
+ kitchen_images = glob.glob(f'examples/kitchen/images/*')
169
+
170
+
171
+
172
+ ###########################################################################################
173
+
174
+
175
+ with gr.Blocks() as demo:
176
+
177
+ gr.Markdown("""
178
+ # 🏛️ VGGT: Visual Geometry Grounded Transformer
179
+
180
+ <div style="font-size: 16px; line-height: 1.2;">
181
+ Alpha version (testing).
182
+ </div>
183
+ """)
184
+
185
+ with gr.Row():
186
+ with gr.Column(scale=1):
187
+ input_video = gr.Video(label="Upload Video", interactive=True)
188
+ input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
189
+
190
+
191
+ with gr.Column(scale=3):
192
+ viser_output = gr.HTML(
193
+ label="Viser Visualization",
194
+ value='''<div style="height: 520px; border: 1px solid #e0e0e0;
195
+ border-radius: 4px; padding: 16px;
196
+ display: flex; align-items: center;
197
+ justify-content: center">
198
+ 3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)
199
+ </div>'''
200
+ )
201
+
202
+ log_output = gr.Textbox(label="Log")
203
+
204
+ with gr.Row():
205
+ submit_btn = gr.Button("Reconstruct", scale=1)
206
+ clear_btn = gr.ClearButton([input_video, input_images, viser_output, log_output], scale=1) #Modified viser_output
207
+
208
+
209
+
210
+
211
+ examples = [
212
+ [flower_video, flower_images],
213
+ [kitchen_video, kitchen_images],
214
+ # [person_video, person_images],
215
+ # [statue_video, statue_images],
216
+ # [drums_video, drums_images],
217
+ [counter_video, counter_images],
218
+ [fern_video, fern_images],
219
+ [horns_video, horns_images],
220
+ # [apple_video, apple_images],
221
+ # [bonsai_video, bonsai_images],
222
+ ]
223
+
224
+ def process_example(video, images):
225
+ """Wrapper function to ensure outputs are properly captured"""
226
+ model_output, log = vggt_demo(video, images)
227
+
228
+ # viser_wrapper(predictions, port=log)
229
+ # Get the hostname - use the actual hostname or IP where the server is running
230
+ # hostname = socket.gethostname()
231
+
232
+ # Extract port from log
233
+ port = log
234
+
235
+ # Create the viser URL using the hostname
236
+ # viser_url = f"http://{hostname}:{port}"
237
+
238
+ viser_url = f"http://localhost:{log}"
239
+ print(f"Viser URL: {viser_url}")
240
+
241
+ # Create the iframe HTML code. Set width and height appropriately.
242
+ iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px"></iframe>'
243
+
244
+
245
+ # Return the iframe code to update the gr.HTML component
246
+ return iframe_code, f"Visualization running at {viser_url}"
247
+
248
+
249
+ # TODO: move the selection of port outside of the demo function
250
+ # so that we can cache examples
251
+
252
+ gr.Examples(examples=examples,
253
+ inputs=[input_video, input_images],
254
+ outputs=[viser_output, log_output], # Output to viser_output
255
+ fn=process_example, # Use our wrapper function
256
+ cache_examples=False,
257
+ examples_per_page=50,
258
+ )
259
+
260
+ submit_btn.click(
261
+ process_example, # Use the same wrapper function
262
+ [input_video, input_images],
263
+ [viser_output, log_output], # Output to viser_output
264
+ # concurrency_limit=1
265
+ )
266
+
267
+ # demo.launch(debug=True, share=True)
268
+ # demo.launch(server_name="0.0.0.0", server_port=8082, debug=True, share=False)
269
+ # demo.queue(max_size=20).launch(show_error=True, share=True)
270
+ demo.queue(max_size=20).launch(show_error=True) #, share=True, server_port=7888, server_name="0.0.0.0")
271
+ # demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)
272
+ ########################################################################################################################
clean_app.py ADDED
@@ -0,0 +1,229 @@
1
+ import os
2
+ import cv2
3
+ import torch
4
+ import numpy as np
5
+ import gradio as gr
6
+ import sys
7
+ import os
8
+ import socket
9
+ import webbrowser
10
+ sys.path.append('vggt/')
11
+ import shutil
12
+ from datetime import datetime
13
+ from demo_hf import demo_fn
14
+ from omegaconf import DictConfig, OmegaConf
15
+ import glob
16
+ import gc
17
+ import time
18
+ from viser_fn import viser_wrapper
19
+
20
+
21
+ def get_free_port():
22
+ """Get a free port using socket."""
23
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
24
+ s.bind(('', 0))
25
+ port = s.getsockname()[1]
26
+ return port
27
+
28
+ def vggt_demo(
29
+ input_video,
30
+ input_image,
31
+ ):
32
+ start_time = time.time()
33
+ gc.collect()
34
+ torch.cuda.empty_cache()
35
+
36
+
37
+ debug = False
38
+
39
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
40
+ target_dir = f"input_images_{timestamp}"
41
+ if os.path.exists(target_dir):
42
+ shutil.rmtree(target_dir)
43
+
44
+ os.makedirs(target_dir)
45
+ target_dir_images = target_dir + "/images"
46
+ os.makedirs(target_dir_images)
47
+
48
+
49
+ if input_video is not None:
50
+ if not isinstance(input_video, str):
51
+ input_video = input_video["video"]["path"]
52
+
53
+ cfg_file = "config/base.yaml"
54
+ cfg = OmegaConf.load(cfg_file)
55
+
56
+ if input_image is not None:
57
+ input_image = sorted(input_image)
58
+ # recon_num = len(input_image)
59
+
60
+ # Copy files to the new directory
61
+ for file_name in input_image:
62
+ shutil.copy(file_name, target_dir_images)
63
+ elif input_video is not None:
64
+ vs = cv2.VideoCapture(input_video)
65
+
66
+ fps = vs.get(cv2.CAP_PROP_FPS)
67
+
68
+ frame_rate = 1
69
+ frame_interval = int(fps * frame_rate)
70
+
71
+ video_frame_num = 0
72
+ count = 0
73
+
74
+ while True:
75
+ (gotit, frame) = vs.read()
76
+ count +=1
77
+
78
+ if not gotit:
79
+ break
80
+
81
+ if count % frame_interval == 0:
82
+ cv2.imwrite(target_dir_images+"/"+f"{video_frame_num:06}.png", frame)
83
+ video_frame_num+=1
84
+ else:
85
+ return None, "Uploading not finished or Incorrect input format"
86
+
87
+
88
+ print(f"Files have been copied to {target_dir_images}")
89
+ cfg.SCENE_DIR = target_dir
90
+
91
+ predictions = demo_fn(cfg)
92
+
93
+ # Get a free port for viser
94
+ viser_port = get_free_port()
95
+
96
+ # Start viser visualization in a separate thread/process
97
+ viser_wrapper(predictions, port=viser_port)
98
+
99
+ del predictions
100
+ gc.collect()
101
+ torch.cuda.empty_cache()
102
+
103
+ print(input_image)
104
+ print(input_video)
105
+ end_time = time.time()
106
+ execution_time = end_time - start_time
107
+ print(f"Execution time: {execution_time} seconds")
108
+ return None, viser_port
109
+
110
+
111
+
112
+
113
+ statue_video = "examples/videos/statue_video.mp4"
114
+
115
+ apple_video = "examples/videos/apple_video.mp4"
116
+ british_museum_video = "examples/videos/british_museum_video.mp4"
117
+ cake_video = "examples/videos/cake_video.mp4"
118
+ bonsai_video = "examples/videos/bonsai_video.mp4"
119
+ face_video = "examples/videos/in2n_face_video.mp4"
120
+ counter_video = "examples/videos/in2n_counter_video.mp4"
121
+
122
+ horns_video = "examples/videos/llff_horns_video.mp4"
123
+ person_video = "examples/videos/in2n_person_video.mp4"
124
+
125
+ flower_video = "examples/videos/llff_flower_video.mp4"
126
+
127
+ fern_video = "examples/videos/llff_fern_video.mp4"
128
+
129
+ drums_video = "examples/videos/drums_video.mp4"
130
+
131
+ kitchen_video = "examples/videos/kitchen_video.mp4"
132
+
133
+ ###########################################################################################
134
+ apple_images = glob.glob(f'examples/apple/images/*')
135
+ bonsai_images = glob.glob(f'examples/bonsai/images/*')
136
+ cake_images = glob.glob(f'examples/cake/images/*')
137
+ british_museum_images = glob.glob(f'examples/british_museum/images/*')
138
+ face_images = glob.glob(f'examples/in2n_face/images/*')
139
+ counter_images = glob.glob(f'examples/in2n_counter/images/*')
140
+
141
+ horns_images = glob.glob(f'examples/llff_horns/images/*')
142
+
143
+ person_images = glob.glob(f'examples/in2n_person/images/*')
144
+ flower_images = glob.glob(f'examples/llff_flower/images/*')
145
+
146
+ fern_images = glob.glob(f'examples/llff_fern/images/*')
147
+ statue_images = glob.glob(f'examples/statue/images/*')
148
+
149
+ drums_images = glob.glob(f'examples/drums/images/*')
150
+ kitchen_images = glob.glob(f'examples/kitchen/images/*')
151
+
152
+
153
+
154
+ ###########################################################################################
155
+
156
+
157
+ with gr.Blocks() as demo:
158
+
159
+ gr.Markdown("""
160
+ # 🏛️ VGGT: Visual Geometry Grounded Transformer
161
+
162
+ <div style="font-size: 16px; line-height: 1.2;">
163
+ Alpha version (testing).
164
+ </div>
165
+ """)
166
+
167
+ with gr.Row():
168
+ with gr.Column(scale=1):
169
+ input_video = gr.Video(label="Upload Video", interactive=True)
170
+ input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
171
+
172
+
173
+ with gr.Column(scale=3):
174
+ viser_output = gr.HTML(
175
+ label="Viser Visualization",
176
+ value='''<div style="height: 520px; border: 1px solid #e0e0e0;
177
+ border-radius: 4px; padding: 16px;
178
+ display: flex; align-items: center;
179
+ justify-content: center">
180
+ 3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)
181
+ </div>'''
182
+ )
183
+
184
+ log_output = gr.Textbox(label="Log")
185
+
186
+ with gr.Row():
187
+ submit_btn = gr.Button("Reconstruct", scale=1)
188
+ clear_btn = gr.ClearButton([input_video, input_images, viser_output, log_output], scale=1) #Modified viser_output
189
+
190
+
191
+
192
+
193
+ examples = [
194
+ [flower_video, flower_images],
195
+ [kitchen_video, kitchen_images],
196
+ [counter_video, counter_images],
197
+ [fern_video, fern_images],
198
+ [horns_video, horns_images],
199
+ ]
200
+
201
+ def process_example(video, images):
202
+ """Wrapper function to ensure outputs are properly captured"""
203
+ model_output, log = vggt_demo(video, images)
204
+
205
+ viser_url = f"http://localhost:{log}"
206
+ print(f"Viser URL: {viser_url}")
207
+
208
+ # Create the iframe HTML code. Set width and height appropriately.
209
+ iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px"></iframe>'
210
+
211
+ return iframe_code, f"Visualization running at {viser_url}"
212
+
213
+ gr.Examples(examples=examples,
214
+ inputs=[input_video, input_images],
215
+ outputs=[viser_output, log_output], # Output to viser_output
216
+ fn=process_example, # Use our wrapper function
217
+ cache_examples=False,
218
+ examples_per_page=50,
219
+ )
220
+
221
+
222
+
223
+ submit_btn.click(
224
+ process_example, # Use the same wrapper function
225
+ [input_video, input_images],
226
+ [viser_output, log_output], # Output to viser_output
227
+ concurrency_limit=1
228
+ )
229
+ demo.queue(max_size=20).launch(show_error=True, share=True, server_port=7888, server_name="0.0.0.0")
config/base.yaml ADDED
@@ -0,0 +1,96 @@
1
+ SCENE_DIR: examples/apple/
2
+ # examples/llff_horns_single/
3
+ # apple
4
+ # cake
5
+
6
+ _target_: vggt.models.vggt.VGGT #off3d.models.vggt.vggt.VGGT
7
+
8
+ num_register_tokens: 4 # 0 for no register tokens
9
+ ffn_layer: "mlp"
10
+ qk_norm: False # NOTE: is this correct?
11
+ patch_size: 14
12
+ init_values: 0.01
13
+
14
+ AGGREGATOR:
15
+ _target_: vggt.models.aggregator.Aggregator
16
+ patch_embed_by_conv: False
17
+ image_size: 518
18
+ use_checkpoint: True
19
+ use_reentrant: False
20
+ decoder_load_dino: False
21
+ backbone_qk_norm: False
22
+ aa_block_kwargs:
23
+ dim: 1024
24
+ num_heads: 16
25
+ mlp_ratio: 4
26
+ qkv_bias: True
27
+ proj_bias: True
28
+ ffn_bias: True
29
+ drop: 0.0
30
+ attn_drop: 0.0
31
+ init_values: 0.01
32
+ drop_path: 0.0
33
+ fused_attn: True
34
+ qk_norm: True
35
+ rope_freq: 100
36
+
37
+
38
+ CameraHead:
39
+ _target_: vggt.heads.camera_head.CameraHead #off3d.models.vggt.camera_head.CameraHead
40
+ pose_encoding_type: "absT_quaR_FoV"
41
+ new_trunk: True
42
+ trunk_depth: 4
43
+ # proj_dim: 768
44
+ qk_norm: True
45
+ init_values: 0.01
46
+ act_dict:
47
+ trans_act: "linear"
48
+ quat_act: "linear"
49
+ fl_act: "linear"
50
+ loss_kwargs:
51
+ loss_type: "l1"
52
+ gamma: 0.6
53
+
54
+
55
+ PointHead:
56
+ _target_: vggt.heads.dpt_head.DPTHead #off3d.models.vggt.dpt_head.DPTHead
57
+ # _target_: off3d.models.vggt.linear_head.LinearHead
58
+ dim_in: 2048
59
+ shallow_conv: False
60
+ normalize_act: "inv_log"
61
+ pos_embed: True
62
+ loss_kwargs:
63
+ gradient_loss: "normal"
64
+ # gradient_loss: "grad"
65
+ normalize_pred: False
66
+ valid_range: 0.98
67
+ gamma: 1.0
68
+ camera_centric_reg: -1.0
69
+ all_mean: True
70
+
71
+ DepthHead: null
72
+ # _target_: vggt.heads.dpt_head.DPTHead #off3d.models.vggt.dpt_head.DPTHead
73
+ # # _target_: off3d.models.vggt.linear_head.LinearHead
74
+ # dim_in: 2048
75
+ # patch_size: ${patch_size}
76
+ # output_dim: 2
77
+ # normalize_act: "exp" # or just relu?
78
+ # normalize_act_conf: "expp1"
79
+ # pos_embed: True
80
+ # loss_kwargs:
81
+ # loss_type: "conf"
82
+ # predict_disparity: False # or True
83
+ # gradient_loss: "grad"
84
+ # valid_range: 0.98
85
+ # gamma: 1.0
86
+ # all_mean: True
87
+
88
+ MatchHead: null
89
+ TrackHead: null
90
+
91
+
92
+
93
+ hydra:
94
+ output_subdir: NULL
95
+ run:
96
+ dir: .
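The `_target_` entries above make this a Hydra-instantiable description of the VGGT model. A minimal sketch of how such a config is typically consumed, mirroring `demo_fn` in `demo_hf.py` below (the path and SCENE_DIR override are illustrative):

```python
# Minimal usage sketch; assumes the vggt package is importable and this file
# lives at config/base.yaml. Mirrors demo_fn in demo_hf.py below.
from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("config/base.yaml")
cfg.SCENE_DIR = "examples/apple/"            # illustrative override of the scene directory
model = instantiate(cfg, _recursive_=False)  # builds vggt.models.vggt.VGGT from _target_
```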
demo_hf.py ADDED
@@ -0,0 +1,149 @@
1
+ import hydra
2
+ import torch
3
+ import os
4
+ from hydra.utils import instantiate
5
+ from omegaconf import DictConfig
6
+ from PIL import Image
7
+ from torchvision import transforms as TF
8
+ import glob
9
+ from vggt.utils.pose_enc import pose_encoding_to_extri_intri
10
+ from viser_fn import viser_wrapper
11
+
12
+
13
+ # @hydra.main(config_path="config", config_name="base")
14
+ def demo_fn(cfg: DictConfig) -> None:
15
+ print(cfg)
16
+ model = instantiate(cfg, _recursive_=False)
17
+
18
+ if not torch.cuda.is_available():
19
+ raise ValueError("CUDA is not available. Check your environment.")
20
+
21
+ device = "cuda"
22
+ model = model.to(device)
23
+
24
+ _VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"
25
+
26
+ # Reload model
27
+ pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)
28
+
29
+ if "model" in pretrain_model:
30
+ model_dict = pretrain_model["model"]
31
+ model.load_state_dict(model_dict, strict=False)
32
+ else:
33
+ model.load_state_dict(pretrain_model, strict=True)
34
+
35
+
36
+ # batch = torch.load("/fsx-repligen/jianyuan/cvpr2025_ckpts/batch.pth")
37
+ # y_hat_raw = torch.load("/fsx-repligen/jianyuan/cvpr2025_ckpts/y_hat.pth")
38
+
39
+
40
+ image_list = glob.glob(os.path.join(cfg.SCENE_DIR, "images", "*"))
41
+ image_list = sorted(image_list)
42
+ images = load_and_preprocess_images(image_list)
43
+ images = images[None].to(device)
44
+
45
+
46
+ batch = {"images": images}
47
+
48
+ with torch.no_grad():
49
+ with torch.cuda.amp.autocast(dtype=torch.float16):
50
+ y_hat = model(batch)
51
+
52
+
53
+ last_pred_pose_enc = y_hat["pred_extrinsic_list"][-1]
54
+ pose_encoding_type = cfg.CameraHead.pose_encoding_type
55
+
56
+ last_pred_extrinsic, _ = pose_encoding_to_extri_intri(last_pred_pose_enc.detach(), None, pose_encoding_type=pose_encoding_type, build_intrinsics=False)
57
+
58
+ y_hat["last_pred_extrinsic"] = last_pred_extrinsic
59
+
60
+
61
+ for key in y_hat.keys():
62
+ if isinstance(y_hat[key], torch.Tensor):
63
+ y_hat[key] = y_hat[key].cpu().numpy()
64
+
65
+ return y_hat
66
+
67
+
68
+
69
+ def load_and_preprocess_images(image_path_list):
70
+ # Check for empty list
71
+ if len(image_path_list) == 0:
72
+ raise ValueError("At least 1 image is required")
73
+
74
+ # 1. load images as RGB
75
+ # 2. resize images to (518, X, 3), where X is the resized width and X should be divisible by 14
76
+ # 3. normalize images to (0, 1)
77
+ # 4. concatenate images to (N, 3, 518, X), where N is the number of images
78
+ images = []
79
+ shapes = set()
80
+ to_tensor = TF.ToTensor()
81
+
82
+ # First process all images and collect their shapes
83
+ for image_path in image_path_list:
84
+ img = Image.open(image_path).convert("RGB")
85
+ width, height = img.size
86
+ new_width = 518
87
+
88
+ # Calculate height maintaining aspect ratio, divisible by 14
89
+ new_height = round(height * (new_width / width) / 14) * 14
90
+
91
+ # Resize with new dimensions (width, height)
92
+
93
+ img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
94
+ img = to_tensor(img) # Convert to tensor (0, 1)
95
+
96
+ # Center crop height if it's larger than 518
97
+
98
+ if new_height > 518:
99
+ start_y = (new_height - 518) // 2
100
+ img = img[:, start_y:start_y + 518, :]
101
+
102
+ shapes.add((img.shape[1], img.shape[2]))
103
+ images.append(img)
104
+
105
+ # Check if we have different shapes
106
+ if len(shapes) > 1:
107
+ print(f"Warning: Found images with different shapes: {shapes}")
108
+ # Find maximum dimensions
109
+ max_height = max(shape[0] for shape in shapes)
110
+ max_width = max(shape[1] for shape in shapes)
111
+
112
+ # Pad images if necessary
113
+ padded_images = []
114
+ for img in images:
115
+ h_padding = max_height - img.shape[1]
116
+ w_padding = max_width - img.shape[2]
117
+
118
+ if h_padding > 0 or w_padding > 0:
119
+ pad_top = h_padding // 2
120
+ pad_bottom = h_padding - pad_top
121
+ pad_left = w_padding // 2
122
+ pad_right = w_padding - pad_left
123
+
124
+ img = torch.nn.functional.pad(
125
+ img,
126
+ (pad_left, pad_right, pad_top, pad_bottom),
127
+ mode='constant',
128
+ value=1.0
129
+ )
130
+ padded_images.append(img)
131
+ images = padded_images
132
+
133
+
134
+ images = torch.stack(images) # concatenate images
135
+
136
+ # Ensure correct shape when single image
137
+ if len(image_path_list) == 1:
138
+ # Verify shape is (1, C, H, W)
139
+ if images.dim() == 3:
140
+ images = images.unsqueeze(0)
141
+
142
+ return images
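A quick worked example of the resize rule above (the input size is illustrative): a 640x480 image is scaled to width 518, giving new_height = round(480 * (518/640) / 14) * 14 = round(27.75) * 14 = 28 * 14 = 392, so the resulting tensor is (3, 392, 518) and no center crop is needed since 392 <= 518.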
143
+
144
+
145
+ # if __name__ == "__main__":
146
+ # y_hat = demo_fn()
147
+ # # viser_wrapper(y_hat, port=8080)
148
+
149
+
gradio_util.py ADDED
@@ -0,0 +1,297 @@
1
+ try:
2
+ import os
3
+
4
+ import trimesh
5
+ import open3d as o3d
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import matplotlib
10
+ from scipy.spatial.transform import Rotation
11
+
12
+ print("Successfully imported the packages for Gradio visualization")
13
+ except ImportError:
14
+ print(
15
+ "Failed to import packages for Gradio visualization. Please disable Gradio visualization."
16
+ )
17
+
18
+
19
+ def visualize_by_gradio(glbfile):
20
+ """
21
+ Set up and launch a Gradio interface to visualize a GLB file.
22
+
23
+ Args:
24
+ glbfile (str): Path to the GLB file to be visualized.
25
+ """
26
+
27
+ def load_glb_file(glb_path):
28
+ # Check if the file exists and return the path or error message
29
+ if os.path.exists(glb_path):
30
+ return glb_path, "3D Model Loaded Successfully"
31
+ else:
32
+ return None, "File not found"
33
+
34
+ # Load the GLB file initially to check if it's valid
35
+ initial_model, log_message = load_glb_file(glbfile)
36
+
37
+ # Create the Gradio interface
38
+ with gr.Blocks() as demo:
39
+ gr.Markdown("# GLB File Viewer")
40
+
41
+ # 3D Model viewer component
42
+ model_viewer = gr.Model3D(
43
+ label="3D Model Viewer", height=600, value=initial_model
44
+ )
45
+
46
+ # Textbox for log output
47
+ log_output = gr.Textbox(label="Log", lines=2, value=log_message)
48
+
49
+ # Launch the Gradio interface
50
+ demo.launch(share=True)
51
+
52
+
53
+ def vggsfm_predictions_to_glb(predictions) -> trimesh.Scene:
54
+ """
55
+ Converts VGG SFM predictions to a 3D scene represented as a GLB.
56
+
57
+ Args:
58
+ predictions (dict): A dictionary containing model predictions.
59
+
60
+ Returns:
61
+ trimesh.Scene: A 3D scene object.
62
+ """
63
+ # Convert predictions to numpy arrays
64
+ vertices_3d = predictions["points3D"].cpu().numpy()
65
+ colors_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(
66
+ np.uint8
67
+ )
68
+
69
+
70
+ if True:
71
+ pcd = o3d.geometry.PointCloud()
72
+ pcd.points = o3d.utility.Vector3dVector(vertices_3d)
73
+ pcd.colors = o3d.utility.Vector3dVector(colors_rgb)
74
+
75
+ cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=1.0)
76
+ filtered_pcd = pcd.select_by_index(ind)
77
+
78
+ print(f"Filter out {len(vertices_3d) - len(filtered_pcd.points)} 3D points")
79
+ vertices_3d = np.asarray(filtered_pcd.points)
80
+ colors_rgb = np.asarray(filtered_pcd.colors).astype(np.uint8)
81
+
82
+
83
+
84
+ camera_matrices = predictions["extrinsics_opencv"].cpu().numpy()
85
+
86
+ # Calculate the 5th and 95th percentiles along each axis
87
+ lower_percentile = np.percentile(vertices_3d, 5, axis=0)
88
+ upper_percentile = np.percentile(vertices_3d, 95, axis=0)
89
+
90
+ # Calculate the diagonal length of the percentile bounding box
91
+ scene_scale = np.linalg.norm(upper_percentile - lower_percentile)
92
+
93
+ colormap = matplotlib.colormaps.get_cmap("gist_rainbow")
94
+
95
+ # Initialize a 3D scene
96
+ scene_3d = trimesh.Scene()
97
+
98
+ # Add point cloud data to the scene
99
+ point_cloud_data = trimesh.PointCloud(
100
+ vertices=vertices_3d, colors=colors_rgb
101
+ )
102
+
103
+ scene_3d.add_geometry(point_cloud_data)
104
+
105
+ # Prepare 4x4 matrices for camera extrinsics
106
+ num_cameras = len(camera_matrices)
107
+ extrinsics_matrices = np.zeros((num_cameras, 4, 4))
108
+ extrinsics_matrices[:, :3, :4] = camera_matrices
109
+ extrinsics_matrices[:, 3, 3] = 1
110
+
111
+ # Add camera models to the scene
112
+ for i in range(num_cameras):
113
+ world_to_camera = extrinsics_matrices[i]
114
+ camera_to_world = np.linalg.inv(world_to_camera)
115
+ rgba_color = colormap(i / num_cameras)
116
+ current_color = tuple(int(255 * x) for x in rgba_color[:3])
117
+
118
+ integrate_camera_into_scene(
119
+ scene_3d, camera_to_world, current_color, scene_scale
120
+ )
121
+
122
+ # Align scene to the observation of the first camera
123
+ scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)
124
+
125
+ return scene_3d
126
+
127
+
128
+ def apply_scene_alignment(
129
+ scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
130
+ ) -> trimesh.Scene:
131
+ """
132
+ Aligns the 3D scene based on the extrinsics of the first camera.
133
+
134
+ Args:
135
+ scene_3d (trimesh.Scene): The 3D scene to be aligned.
136
+ extrinsics_matrices (np.ndarray): Camera extrinsic matrices.
137
+
138
+ Returns:
139
+ trimesh.Scene: Aligned 3D scene.
140
+ """
141
+ # Set transformations for scene alignment
142
+ opengl_conversion_matrix = get_opengl_conversion_matrix()
143
+
144
+ # Rotation matrix for alignment (180 degrees around the y-axis)
145
+ align_rotation = np.eye(4)
146
+ align_rotation[:3, :3] = Rotation.from_euler(
147
+ "y", 180, degrees=True
148
+ ).as_matrix()
149
+
150
+ # Apply transformation
151
+ initial_transformation = (
152
+ np.linalg.inv(extrinsics_matrices[0])
153
+ @ opengl_conversion_matrix
154
+ @ align_rotation
155
+ )
156
+ scene_3d.apply_transform(initial_transformation)
157
+ return scene_3d
158
+
159
+
160
+ def integrate_camera_into_scene(
161
+ scene: trimesh.Scene,
162
+ transform: np.ndarray,
163
+ face_colors: tuple,
164
+ scene_scale: float,
165
+ ):
166
+ """
167
+ Integrates a fake camera mesh into the 3D scene.
168
+
169
+ Args:
170
+ scene (trimesh.Scene): The 3D scene to add the camera model.
171
+ transform (np.ndarray): Transformation matrix for camera positioning.
172
+ face_colors (tuple): Color of the camera face.
173
+ scene_scale (float): Scale of the scene.
174
+ """
175
+
176
+ cam_width = scene_scale * 0.05
177
+ cam_height = scene_scale * 0.1
178
+
179
+ # Create cone shape for camera
180
+ rot_45_degree = np.eye(4)
181
+ rot_45_degree[:3, :3] = Rotation.from_euler(
182
+ "z", 45, degrees=True
183
+ ).as_matrix()
184
+ rot_45_degree[2, 3] = -cam_height
185
+
186
+ opengl_transform = get_opengl_conversion_matrix()
187
+ # Combine transformations
188
+ complete_transform = transform @ opengl_transform @ rot_45_degree
189
+ camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)
190
+
191
+ # Generate mesh for the camera
192
+ slight_rotation = np.eye(4)
193
+ slight_rotation[:3, :3] = Rotation.from_euler(
194
+ "z", 2, degrees=True
195
+ ).as_matrix()
196
+
197
+ vertices_combined = np.concatenate(
198
+ [
199
+ camera_cone_shape.vertices,
200
+ 0.95 * camera_cone_shape.vertices,
201
+ transform_points(slight_rotation, camera_cone_shape.vertices),
202
+ ]
203
+ )
204
+ vertices_transformed = transform_points(
205
+ complete_transform, vertices_combined
206
+ )
207
+
208
+ mesh_faces = compute_camera_faces(camera_cone_shape)
209
+
210
+ # Add the camera mesh to the scene
211
+ camera_mesh = trimesh.Trimesh(
212
+ vertices=vertices_transformed, faces=mesh_faces
213
+ )
214
+ camera_mesh.visual.face_colors[:, :3] = face_colors
215
+ scene.add_geometry(camera_mesh)
216
+
217
+
218
+ def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
219
+ """
220
+ Computes the faces for the camera mesh.
221
+
222
+ Args:
223
+ cone_shape (trimesh.Trimesh): The shape of the camera cone.
224
+
225
+ Returns:
226
+ np.ndarray: Array of faces for the camera mesh.
227
+ """
228
+ # Create pseudo cameras
229
+ faces_list = []
230
+ num_vertices_cone = len(cone_shape.vertices)
231
+
232
+ for face in cone_shape.faces:
233
+ if 0 in face:
234
+ continue
235
+ v1, v2, v3 = face
236
+ v1_offset, v2_offset, v3_offset = face + num_vertices_cone
237
+ v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone
238
+
239
+ faces_list.extend(
240
+ [
241
+ (v1, v2, v2_offset),
242
+ (v1, v1_offset, v3),
243
+ (v3_offset, v2, v3),
244
+ (v1, v2, v2_offset_2),
245
+ (v1, v1_offset_2, v3),
246
+ (v3_offset_2, v2, v3),
247
+ ]
248
+ )
249
+
250
+ faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
251
+ return np.array(faces_list)
252
+
253
+
254
+ def transform_points(
255
+ transformation: np.ndarray, points: np.ndarray, dim: int = None
256
+ ) -> np.ndarray:
257
+ """
258
+ Applies a 4x4 transformation to a set of points.
259
+
260
+ Args:
261
+ transformation (np.ndarray): Transformation matrix.
262
+ points (np.ndarray): Points to be transformed.
263
+ dim (int, optional): Dimension for reshaping the result.
264
+
265
+ Returns:
266
+ np.ndarray: Transformed points.
267
+ """
268
+ points = np.asarray(points)
269
+ initial_shape = points.shape[:-1]
270
+ dim = dim or points.shape[-1]
271
+
272
+ # Apply transformation
273
+ transformation = transformation.swapaxes(
274
+ -1, -2
275
+ ) # Transpose the transformation matrix
276
+ points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]
277
+
278
+ # Reshape the result
279
+ result = points[..., :dim].reshape(*initial_shape, dim)
280
+ return result
281
+
282
+
283
+ def get_opengl_conversion_matrix() -> np.ndarray:
284
+ """
285
+ Constructs and returns the OpenGL conversion matrix.
286
+
287
+ Returns:
288
+ numpy.ndarray: A 4x4 OpenGL conversion matrix.
289
+ """
290
+ # Create an identity matrix
291
+ matrix = np.identity(4)
292
+
293
+ # Flip the y and z axes
294
+ matrix[1, 1] = -1
295
+ matrix[2, 2] = -1
296
+
297
+ return matrix
requirements.txt ADDED
@@ -0,0 +1,28 @@
1
+ torch==2.4.0
2
+ torchvision==0.19.0
3
+ hydra-core==1.3.2
4
+ scipy
5
+ omegaconf
6
+ opencv-python
7
+ einops
8
+ numpy==1.26.3
9
+ viser
10
+
11
+
12
+
13
+
14
+ # accelerate==0.24.0
15
+ # git+https://github.com/cvg/LightGlue.git#egg=LightGlue
16
+ # pycolmap==0.6.1
17
+ # https://huggingface.co/facebook/VGGSfM/resolve/main/poselib-2.0.2-cp310-cp310-linux_x86_64.whl
18
+ # trimesh
19
+ # open3d
20
+
21
+ # hydra-core==1.3.2
22
+ # scipy
23
+ # omegaconf
24
+ # opencv-python
25
+ # einops
26
+ # numpy==1.26.3
27
+ # trimesh
28
+ # open3d
vggt/heads/camera_head.py ADDED
@@ -0,0 +1,220 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ import numpy as np
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from hydra.utils import instantiate
14
+
15
+ from vggt.layers.block import Block
16
+
17
+ from vggt.layers import Mlp
18
+ from vggt.heads.utils import PoseEmbedding
19
+ from vggt.heads.head_act import activate_pose
20
+
21
+ def modulate(x, shift, scale):
22
+ # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
23
+ return x * (1 + scale) + shift
24
+
25
+
26
+
27
+ class CameraHead(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim_in=2048,
31
+ patch_size=14,
32
+ qk_norm=False,
33
+ trunk_depth=4,
34
+ new_trunk=True,
35
+ update_new_trunk_tokens=False,
36
+ pose_encoding_type="absT_quaR_FoV",
37
+ proj_dim=-1,
38
+ num_heads=16,
39
+ mlp_ratio=4,
40
+ init_values=None,
41
+ act_dict=None,
42
+ **kwargs,
43
+ ):
44
+ super().__init__()
45
+
46
+ # Three types:
47
+ # 1. Linear projection
48
+ # 2. New trunk
49
+ # 3. Old trunk
50
+
51
+ self.new_trunk = new_trunk
52
+ if pose_encoding_type=="absT_quaR_FoV":
53
+ self.target_dim = 9
54
+ elif pose_encoding_type=="absT_quaR_OneFLM1":
55
+ self.target_dim = 8
56
+ else:
57
+ raise ValueError(f"Unsupported pose encoding type: {pose_encoding_type}")
58
+
59
+ self.update_new_trunk_tokens = update_new_trunk_tokens
60
+ self.act_dict = act_dict
61
+ self.trunk_depth = trunk_depth
62
+
63
+ self.token_norm = nn.LayerNorm(dim_in)
64
+
65
+ if proj_dim > 0:
66
+ self.proj = nn.Linear(dim_in, proj_dim)
67
+ dim_in = proj_dim
68
+ else:
69
+ self.proj = nn.Identity()
70
+
71
+ if self.trunk_depth <0:
72
+ self.pose_branch = nn.Linear(dim_in, self.target_dim)
73
+ else:
74
+ self.trunk = nn.Sequential(
75
+ *[
76
+ Block(
77
+ dim=dim_in,
78
+ num_heads=num_heads,
79
+ mlp_ratio=mlp_ratio,
80
+ qk_norm=qk_norm,
81
+ init_values=init_values,
82
+ )
83
+ for _ in range(trunk_depth)
84
+ ]
85
+ )
86
+ self.trunk_norm = nn.LayerNorm(dim_in)
87
+
88
+ if self.new_trunk:
89
+ # TODO: self.empty_pose_tokens -> BxSxC
90
+ self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
91
+ self.embed_pose = nn.Linear(self.target_dim, dim_in)
92
+
93
+ self.poseLN_modulation = nn.Sequential(
94
+ nn.SiLU(),
95
+ nn.Linear(dim_in, 3 * dim_in, bias=True)
96
+ )
97
+
98
+ self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
99
+ self.pose_branch = Mlp(
100
+ in_features=dim_in,
101
+ hidden_features=dim_in // 2,
102
+ out_features=self.target_dim,
103
+ drop=0,
104
+ )
105
+ else:
106
+ self.ffeat_norm = nn.LayerNorm(dim_in)
107
+ self.pose_branch = Mlp(
108
+ in_features=dim_in,
109
+ hidden_features=dim_in * 2,
110
+ out_features=dim_in + self.target_dim,
111
+ drop=0,
112
+ )
113
+
114
+ self.ffeat_updater = nn.Sequential(
115
+ nn.Linear(dim_in, dim_in), nn.GELU()
116
+ )
117
+
118
+ # sine and cosine embed for camera parameters
119
+ self.embed_pose = PoseEmbedding(
120
+ target_dim=self.target_dim,
121
+ n_harmonic_functions=(dim_in // self.target_dim) // 2,
122
+ append_input=False,
123
+ )
124
+ self.embed_pose_proj = nn.Linear(self.embed_pose.out_dim, dim_in)
125
+
126
+
127
+ def forward(self, aggregated_tokens_list, batch, patch_start_idx, iters=4,):
128
+ """
129
+ """
130
+ tokens = aggregated_tokens_list[-1]
131
+ # only use the Pose token for camera prediction
132
+ pose_tokens = tokens[:, :, 0]
133
+ pose_tokens = self.token_norm(pose_tokens)
134
+ pose_tokens = self.proj(pose_tokens)
135
+
136
+ B, S, C = pose_tokens.shape
137
+
138
+ if self.trunk_depth < 0:
139
+ pred_pose_enc = self.pose_branch(pose_tokens)
140
+ pred_pose_enc_list = [activate_pose(pred_pose_enc, **self.act_dict)]
141
+ elif self.new_trunk:
142
+ pred_pose_enc_list = self.new_trunk_fn(pose_tokens, iters)
143
+ else:
144
+ pred_pose_enc_list = self.old_trunk_fn(pose_tokens, iters)
145
+
146
+
147
+ # TODO add act here
148
+ return pred_pose_enc_list
149
+
150
+
151
+ def new_trunk_fn(self, pose_tokens, iters):
152
+ B, S, C = pose_tokens.shape
153
+
154
+ pred_pose_enc = None
155
+ pose_tokens_init = pose_tokens.clone()
156
+
157
+ pred_pose_enc_list = []
158
+
159
+ for iter_num in range(iters):
160
+ if pred_pose_enc is None:
161
+ # model_input = self.empty_representation BxSxC
162
+ module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
163
+ else:
164
+ pred_pose_enc = pred_pose_enc.detach()
165
+ module_input = self.embed_pose(pred_pose_enc)
166
+
167
+ shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
168
+ pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
169
+ pose_tokens_modulated = pose_tokens_modulated + pose_tokens
170
+
171
+ pose_tokens_modulated = self.trunk(pose_tokens_modulated)
172
+ pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
173
+
174
+ if pred_pose_enc is None:
175
+ pred_pose_enc = pred_pose_enc_delta
176
+ else:
177
+ pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
178
+
179
+ if self.update_new_trunk_tokens:
180
+ pose_tokens = pose_tokens_modulated + pose_tokens_init
181
+
182
+ pred_pose_enc_list.append(activate_pose(pred_pose_enc, **self.act_dict))
183
+
184
+ return pred_pose_enc_list
185
+
186
+
187
+ def old_trunk_fn(self, pose_tokens, iters):
188
+ B, S, C = pose_tokens.shape
189
+
190
+ pred_pose_enc = torch.zeros(B, S, self.target_dim).to(
191
+ pose_tokens.device
192
+ )
193
+
194
+ pose_tokens_init = pose_tokens.clone()
195
+
196
+ pred_pose_enc_list = []
197
+
198
+ for iter_num in range(iters):
199
+ pred_pose_enc = pred_pose_enc.detach()
200
+
201
+ # Embed the camera parameters and add to pose_tokens
202
+ pose_embed = self.embed_pose_proj(self.embed_pose(pred_pose_enc))
203
+ pose_tokens = pose_tokens + pose_embed
204
+
205
+ # Run trunk transformers on pose_tokens
206
+ pose_tokens = self.trunk(pose_tokens)
207
+
208
+ # Predict the delta feat and pose encoding at each iteration
209
+ delta = self.pose_branch(self.trunk_norm(pose_tokens))
210
+ delta_pred_pose_enc = delta[..., : self.target_dim]
211
+ delta_feat = delta[..., self.target_dim :]
212
+
213
+ pose_tokens = self.ffeat_updater(self.ffeat_norm(delta_feat)) + pose_tokens
214
+
215
+ pred_pose_enc = pred_pose_enc + delta_pred_pose_enc
216
+ pose_tokens = (pose_tokens + pose_tokens_init) / 2
217
+ pred_pose_enc_list.append(activate_pose(pred_pose_enc, **self.act_dict))
218
+
219
+ return pred_pose_enc_list
220
+
vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,521 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # linear head implementation for DUST3R
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from .head_act import activate_head
13
+ from .utils import normalized_view_plane_uv, HarmonicEmbedding, position_grid_to_embed
14
+
15
+ class DPTHead(nn.Module):
16
+ """
17
+ """
18
+ def __init__(self,
19
+ dim_in,
20
+ patch_size = 14,
21
+ output_dim = 4,
22
+ normalize_act="inv_log",
23
+ normalize_act_conf = "expp1",
24
+ features=256,
25
+ use_bn=False,
26
+ use_clstoken=False,
27
+ out_channels=[256, 512, 1024, 1024],
28
+ intermediate_layer_idx=[4, 11, 17, 23],
29
+ shared_norm = True,
30
+ add_rgb = False,
31
+ head_use_checkpoint=False,
32
+ groups=1,
33
+ shallow_conv=False,
34
+ load_da_str=None,
35
+ dpt_layer_norm=False,
36
+ pos_embed = False,
37
+ feature_only = False,
38
+ down_ratio = 1,
39
+ **kwargs,
40
+ ):
41
+ super(DPTHead, self).__init__()
42
+
43
+ in_channels = dim_in
44
+ self.add_rgb = add_rgb
45
+ self.patch_size = patch_size
46
+ self.intermediate_layer_idx = intermediate_layer_idx
47
+ self.shared_norm = shared_norm
48
+ self.normalize_act = normalize_act
49
+ self.normalize_act_conf = normalize_act_conf
50
+ self.head_use_checkpoint = head_use_checkpoint
51
+ self.pos_embed = pos_embed
52
+ self.feature_only = feature_only
53
+ self.down_ratio = down_ratio
54
+
55
+ # if self.pos_embed:
56
+ # self.pose_embed_fn_64 = HarmonicEmbedding(n_harmonic_functions=64, omega_0=1.0, logspace=True, append_input=False)
57
+ # self.pose_embed_fn_128 = HarmonicEmbedding(n_harmonic_functions=128, omega_0=1.0, logspace=True, append_input=False)
58
+ # self.pose_embed_fn_256 = HarmonicEmbedding(n_harmonic_functions=256, omega_0=1.0, logspace=True, append_input=False)
59
+ # self.pose_embed_fn_512 = HarmonicEmbedding(n_harmonic_functions=512, omega_0=1.0, logspace=True, append_input=False)
60
+ # self.pose_embed_fn_1024 = HarmonicEmbedding(n_harmonic_functions=1024, omega_0=1.0, logspace=True, append_input=False)
61
+
62
+ if self.shared_norm:
63
+ self.norm = nn.LayerNorm(in_channels)
64
+ else:
65
+ self.norm = nn.ModuleList([nn.LayerNorm(in_channels) for _ in range(len(self.intermediate_layer_idx))])
66
+
67
+ self.use_clstoken = use_clstoken
68
+
69
+ self.projects = nn.ModuleList([
70
+ nn.Conv2d(
71
+ in_channels=in_channels,
72
+ out_channels=out_channel,
73
+ kernel_size=1,
74
+ stride=1,
75
+ padding=0,
76
+ ) for out_channel in out_channels
77
+ ])
78
+
79
+ self.resize_layers = nn.ModuleList([
80
+ nn.ConvTranspose2d(
81
+ in_channels=out_channels[0],
82
+ out_channels=out_channels[0],
83
+ kernel_size=4,
84
+ stride=4,
85
+ padding=0),
86
+ nn.ConvTranspose2d(
87
+ in_channels=out_channels[1],
88
+ out_channels=out_channels[1],
89
+ kernel_size=2,
90
+ stride=2,
91
+ padding=0),
92
+ nn.Identity(),
93
+ nn.Conv2d(
94
+ in_channels=out_channels[3],
95
+ out_channels=out_channels[3],
96
+ kernel_size=3,
97
+ stride=2,
98
+ padding=1)
99
+ ])
100
+
101
+ if use_clstoken:
102
+ raise ValueError("CLS token is not supported for DPT head Now")
103
+ self.readout_projects = nn.ModuleList()
104
+ for _ in range(len(self.projects)):
105
+ self.readout_projects.append(
106
+ nn.Sequential(
107
+ nn.Linear(2 * in_channels, in_channels),
108
+ nn.GELU()))
109
+
110
+ self.scratch = _make_scratch(
111
+ out_channels,
112
+ features,
113
+ groups=1,
114
+ expand=False,
115
+ )
116
+
117
+ self.scratch.stem_transpose = None
118
+
119
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
120
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
121
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
122
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn, has_residual=False, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
123
+
124
+ head_features_1 = features
125
+ head_features_2 = 32
126
+
127
+
128
+
129
+
130
+ if not self.feature_only:
131
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
132
+ conv2_in_channels = head_features_1 // 2 + 3 * int(self.add_rgb)
133
+
134
+ if dpt_layer_norm:
135
+ self.scratch.output_conv2 = nn.Sequential(
136
+ ChannelLayerNorm(conv2_in_channels),
137
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
138
+ nn.ReLU(True),
139
+ ChannelLayerNorm(head_features_2),
140
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
141
+ # nn.ReLU(True),
142
+ # nn.Identity(),
143
+ )
144
+ else:
145
+ self.scratch.output_conv2 = nn.Sequential(
146
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
147
+ nn.ReLU(True),
148
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
149
+ # nn.ReLU(True),
150
+ # nn.Identity(),
151
+ )
152
+ else:
153
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
154
+
155
+
156
+
157
+ if load_da_str is not None:
158
+ from off3d.utils.train_utils import remove_if_not_match
159
+
160
+ da_path = os.path.join(torch.hub.get_dir(), load_da_str)
161
+ da_model = torch.load(da_path)
162
+ to_load_dict = {}
163
+ for k in da_model.keys():
164
+ if "depth_head" in k:
165
+ to_load_dict[k.replace("depth_head.", "")] = da_model[k]
166
+ all_keys = list(to_load_dict.keys())
167
+ model_state_dict = self.state_dict()
168
+ for cur_key in all_keys:
169
+ to_load_dict = remove_if_not_match(model_state_dict, to_load_dict, cur_key)
170
+
171
+ missing, unexpected = self.load_state_dict(to_load_dict, strict=False)
172
+
173
+ print("Missing keys in DPT head: ", missing)
174
+ print("Unexpected keys in DPT head: ", unexpected)
175
+ for layer in self.scratch.output_conv2:
176
+ if isinstance(layer, (nn.Conv2d, nn.Linear)):
177
+ layer.weight.data *= 0.1
178
+ layer.bias.data *= 0.1
179
+
180
+
181
+
182
+
183
+
184
+ def forward(self, aggregated_tokens_list, batch, patch_start_idx):
185
+
186
+ B, _, _, H, W = batch["images"].shape
187
+ S = aggregated_tokens_list[0].shape[1]
188
+
189
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
190
+
191
+ # TODO use rgb as input for the DPT head
192
+
193
+ out = []
194
+
195
+ dpt_idx = 0
196
+
197
+ for layer_idx in self.intermediate_layer_idx:
198
+ if self.use_clstoken:
199
+ raise NotImplementedError("CLS token is not supported for DPT head Now")
200
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
201
+ x = x.view(B*S, -1, x.shape[-1])
202
+
203
+ if self.shared_norm:
204
+ x = self.norm(x)
205
+ else:
206
+ x = self.norm[dpt_idx](x)
207
+
208
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
209
+
210
+ if self.head_use_checkpoint:
211
+ # e.g., from Bx2048xpatch_h*patch_w to Bx256xpatch_h*patch_w
212
+ x = torch.utils.checkpoint.checkpoint(self.projects[dpt_idx], x, use_reentrant=False)
213
+ if self.pos_embed:
214
+ x = self._apply_pos_embed(x, W, H)
215
+ x = torch.utils.checkpoint.checkpoint(self.resize_layers[dpt_idx], x, use_reentrant=False)
216
+ else:
217
+ x = self.projects[dpt_idx](x)
218
+ if self.pos_embed:
219
+ x = self._apply_pos_embed(x, W, H)
220
+ x = self.resize_layers[dpt_idx](x)
221
+
222
+ out.append(x)
223
+ dpt_idx += 1
224
+
225
+ if self.head_use_checkpoint:
226
+ out = torch.utils.checkpoint.checkpoint(self.scratch_forward, out, use_reentrant=False)
227
+ else:
228
+ out = self.scratch_forward(out)
229
+
230
+ # out = F.interpolate(out, (int(patch_h * self.patch_size), int(patch_w * self.patch_size)), mode="bilinear", align_corners=True)
231
+ out = custom_interpolate(out, (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), mode="bilinear", align_corners=True)
232
+
233
+ if self.pos_embed:
234
+ out = self._apply_pos_embed(out, W, H)
235
+
236
+ if self.feature_only:
237
+ return out
238
+
239
+
240
+ if self.add_rgb:
241
+ # NOTE batch["images"] is in the range of [0, 1]
242
+ out = torch.cat([out, batch["images"].view(B*S, 3, H, W).clip(0, 1)], dim=1)
243
+
244
+
245
+ if self.head_use_checkpoint:
246
+ out = torch.utils.checkpoint.checkpoint(self.scratch.output_conv2, out, use_reentrant=False)
247
+ else:
248
+ out = self.scratch.output_conv2(out)
249
+
250
+ preds, conf = activate_head(out, normalize_act=self.normalize_act, normalize_act_conf=self.normalize_act_conf)
251
+
252
+ # back to B, S
253
+ # B, S, H, W, 3
254
+ preds = preds.view(B, S, *preds.shape[1:])
255
+ # B, S, H, W
256
+ conf = conf.view(B, S, *conf.shape[1:])
257
+
258
+ return preds, conf
259
+
260
+
261
+ def _apply_pos_embed(self, x, W, H, ratio=0.1):
262
+ """Apply positional embedding to the input tensor."""
263
+ patch_w = x.shape[-1]
264
+ patch_h = x.shape[-2]
265
+
266
+ pos_embed = normalized_view_plane_uv(patch_w, patch_h, aspect_ratio=W/H, dtype=x.dtype, device=x.device)
267
+
268
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
269
+ pos_embed = pos_embed * ratio
270
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
271
+ return x + pos_embed
272
+
273
+
274
+ def scratch_forward(self, out):
275
+ layer_1, layer_2, layer_3, layer_4 = out
276
+
277
+
278
+ layer_1_rn = self.scratch.layer1_rn(layer_1) # layer_1:[32, 256, 148, 148]
279
+ layer_2_rn = self.scratch.layer2_rn(layer_2) # layer_2:[32, 512, 74, 74]
280
+ layer_3_rn = self.scratch.layer3_rn(layer_3) # layer_3:[32, 1024, 37, 37]
281
+ layer_4_rn = self.scratch.layer4_rn(layer_4) # layer_4:[32, 1024, 19, 19]
282
+
283
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
284
+ del layer_4_rn, layer_4
285
+
286
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
287
+ del layer_3_rn, layer_3
288
+
289
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
290
+ del layer_2_rn, layer_2
291
+
292
+ out = self.scratch.refinenet1(out, layer_1_rn)
293
+ del layer_1_rn, layer_1
294
+
295
+ out = self.scratch.output_conv1(out)
296
+ return out
297
+
298
+
299
+
300
+
301
+
302
+ ################################################################################
303
+
304
+ # Modules
305
+
306
+
307
+
308
+ def _make_fusion_block(features, use_bn, size=None, has_residual=True, groups=1, shallow_conv=False, dpt_layer_norm=False):
309
+ return FeatureFusionBlock(
310
+ features,
311
+ nn.ReLU(True),
312
+ deconv=False,
313
+ bn=use_bn,
314
+ expand=False,
315
+ align_corners=True,
316
+ size=size,
317
+ has_residual=has_residual,
318
+ groups=groups,
319
+ shallow_conv=shallow_conv,
320
+ dpt_layer_norm=dpt_layer_norm,
321
+ )
322
+
323
+
324
+
325
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
326
+ scratch = nn.Module()
327
+
328
+ out_shape1 = out_shape
329
+ out_shape2 = out_shape
330
+ out_shape3 = out_shape
331
+ if len(in_shape) >= 4:
332
+ out_shape4 = out_shape
333
+
334
+ if expand:
335
+ out_shape1 = out_shape
336
+ out_shape2 = out_shape * 2
337
+ out_shape3 = out_shape * 4
338
+ if len(in_shape) >= 4:
339
+ out_shape4 = out_shape * 8
340
+
341
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
342
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
343
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
344
+ if len(in_shape) >= 4:
345
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
346
+
347
+ return scratch
348
+
349
+
350
+
351
+
352
+ class ResidualConvUnit(nn.Module):
353
+ """Residual convolution module.
354
+ """
355
+
356
+ def __init__(self, features, activation, bn, groups=1, shallow_conv=False, dpt_layer_norm=False):
357
+ """Init.
358
+
359
+ Args:
360
+ features (int): number of features
361
+ """
362
+ super().__init__()
363
+
364
+ self.bn = bn
365
+
366
+ self.groups=groups
367
+
368
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
369
+
370
+ self.shallow_conv = shallow_conv
371
+ if not self.shallow_conv:
372
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
373
+
374
+ # if self.bn == True:
375
+ # self.bn1 = nn.BatchNorm2d(features)
376
+ # self.bn2 = nn.BatchNorm2d(features)
377
+ # otherwise, use channel-wise LayerNorm when dpt_layer_norm is set:
378
+
379
+ if dpt_layer_norm:
380
+ self.norm1 = ChannelLayerNorm(features)
381
+ self.norm2 = ChannelLayerNorm(features)
382
+ else:
383
+ self.norm1 = None
384
+ self.norm2 = None
385
+
386
+ self.activation = activation
387
+
388
+ self.skip_add = nn.quantized.FloatFunctional()
389
+
390
+ def forward(self, x):
391
+ """Forward pass.
392
+
393
+ Args:
394
+ x (tensor): input
395
+
396
+ Returns:
397
+ tensor: output
398
+ """
399
+
400
+ out = self.activation(x)
401
+ out = self.conv1(out)
402
+ if self.norm1 is not None:
403
+ out = self.norm1(out)
404
+
405
+ if not self.shallow_conv:
406
+ out = self.activation(out)
407
+ out = self.conv2(out)
408
+ if self.norm2 is not None:
409
+ out = self.norm2(out)
410
+
411
+ # if self.groups > 1:
412
+ # out = self.conv_merge(out)
413
+
414
+ return self.skip_add.add(out, x)
415
+
416
+
417
+ class FeatureFusionBlock(nn.Module):
418
+ """Feature fusion block.
419
+ """
420
+
421
+ def __init__(
422
+ self,
423
+ features,
424
+ activation,
425
+ deconv=False,
426
+ bn=False,
427
+ expand=False,
428
+ align_corners=True,
429
+ size=None,
430
+ has_residual=True,
431
+ groups=1,
432
+ shallow_conv=False,
433
+ dpt_layer_norm=False,
434
+ ):
435
+ """Init.
436
+
437
+ Args:
438
+ features (int): number of features
439
+ """
440
+ super(FeatureFusionBlock, self).__init__()
441
+
442
+ self.deconv = deconv
443
+ self.align_corners = align_corners
444
+
445
+ self.groups=groups
446
+
447
+ self.expand = expand
448
+ out_features = features
449
+ if self.expand == True:
450
+ out_features = features // 2
451
+
452
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups)
453
+
454
+ if has_residual:
455
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
456
+
457
+ self.has_residual = has_residual
458
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
459
+
460
+ self.skip_add = nn.quantized.FloatFunctional()
461
+
462
+ self.size=size
463
+
464
+ def forward(self, *xs, size=None):
465
+ """Forward pass.
466
+
467
+ Returns:
468
+ tensor: output
469
+ """
470
+ output = xs[0]
471
+
472
+ if self.has_residual:
473
+ res = self.resConfUnit1(xs[1])
474
+ output = self.skip_add.add(output, res)
475
+
476
+ output = self.resConfUnit2(output)
477
+
478
+ if (size is None) and (self.size is None):
479
+ modifier = {"scale_factor": 2}
480
+ elif size is None:
481
+ modifier = {"size": self.size}
482
+ else:
483
+ modifier = {"size": size}
484
+
485
+ # output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
486
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
487
+
488
+ output = self.out_conv(output)
489
+
490
+ return output
491
+
492
+
493
+
494
+ def custom_interpolate(x, size=None, scale_factor=None, mode="bilinear", align_corners=True):
495
+ if size is None:
496
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
497
+ INT_MAX = 1610612736
498
+
499
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
500
+
501
+ if input_elements > INT_MAX:
502
+ # Split x into chunks along the batch dimension
503
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
504
+ interpolated_chunks = [nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks]
505
+ x = torch.cat(interpolated_chunks, dim=0)
506
+ return x.contiguous()
507
+ else:
508
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
509
+
510
+
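A quick usage sketch of the custom_interpolate defined above (shapes and the module path are illustrative assumptions): small tensors go straight through F.interpolate, and only outputs exceeding INT_MAX elements are chunked along the batch dimension and re-concatenated.

import torch
from vggt.heads.dpt_head import custom_interpolate   # assumed module path for this file

x = torch.randn(2, 16, 32, 32)
y = custom_interpolate(x, size=(64, 64), mode="bilinear", align_corners=True)
assert y.shape == (2, 16, 64, 64)   # small output: plain F.interpolate path
# only when B * C * out_H * out_W exceeds INT_MAX is x split into batch chunks and re-concatenated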
511
+ class ChannelLayerNorm(nn.Module):
512
+ def __init__(self, num_channels):
513
+ super().__init__()
514
+ self.ln = nn.LayerNorm(num_channels)
515
+
516
+ def forward(self, x):
517
+ # x: [N, C, H, W]
518
+ x = x.permute(0, 2, 3, 1) # -> [N, H, W, C]
519
+ x = self.ln(x) # now LN sees 'C' as the last dimension
520
+ x = x.permute(0, 3, 1, 2) # -> [N, C, H, W]
521
+ return x
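ChannelLayerNorm above is just nn.LayerNorm applied over the channel axis of an NCHW tensor; a minimal sketch of the equivalence (tensor sizes are illustrative, and ChannelLayerNorm is assumed importable from vggt.heads.dpt_head):

import torch
import torch.nn as nn
from vggt.heads.dpt_head import ChannelLayerNorm   # assumed module path

feat = torch.randn(2, 64, 8, 8)                                    # N, C, H, W
cln = ChannelLayerNorm(64)
ref = nn.LayerNorm(64)(feat.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
assert torch.allclose(cln(feat), ref)                              # normalized per spatial location, over channels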
vggt/heads/head_act.py ADDED
@@ -0,0 +1,97 @@
1
+
2
+
3
+
4
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
5
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
6
+ #
7
+ # --------------------------------------------------------
8
+ # post-processing functions for all heads: extract 3D points / confidence from outputs
9
+ # --------------------------------------------------------
10
+ import torch
11
+ import torch.nn.functional as F
12
+
13
+
14
+
15
+
16
+ def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
17
+ T = pred_pose_enc[..., :3]
18
+ quat = pred_pose_enc[..., 3:7]
19
+ fl = pred_pose_enc[..., 7:] # or fov
20
+
21
+ T = base_pose_act(T, trans_act)
22
+ quat = base_pose_act(quat, quat_act)
23
+ fl = base_pose_act(fl, fl_act) # or fov
24
+
25
+ pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
26
+
27
+ return pred_pose_enc
28
+
29
+
30
+ def base_pose_act(pose_enc, act_type="linear"):
31
+ if act_type == "linear":
32
+ return pose_enc
33
+ elif act_type == "inv_log":
34
+ return inverse_log_transform(pose_enc)
35
+ elif act_type == "exp":
36
+ return torch.exp(pose_enc)
37
+ elif act_type == "relu":
38
+ return F.relu(pose_enc)
39
+ else:
40
+ raise ValueError(f"Unknown act_type: {act_type}")
41
+
42
+
43
+
44
+ def activate_head(out, normalize_act="norm_exp", normalize_act_conf="expp1"):
45
+ """Convert raw head output of shape (B, C, H, W) into 3D point maps (B, H, W, C-1) and a confidence map (B, H, W)."""
47
+ # Move channels from dim 1 to the last dimension => (B, H, W, C)
48
+ fmap = out.permute(0, 2, 3, 1) # B,H,W, C expected
49
+
50
+ # Split into xyz (first C-1 channels) and confidence (last channel)
51
+ xyz = fmap[:, :, :, :-1]
52
+ conf = fmap[:, :, :, -1]
53
+
54
+ if normalize_act == "norm_exp":
55
+ # 1) distance d = ||xyz||
56
+ # 2) normalize xyz => xyz / d
57
+ # 3) multiply by torch.expm1(d)
58
+ d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
59
+ xyz_normed = xyz / d
60
+ pts3d = xyz_normed * torch.expm1(d)
61
+ elif normalize_act == "norm":
62
+ pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
63
+ elif normalize_act == "exp":
64
+ pts3d = torch.exp(xyz)
65
+ elif normalize_act == "relu":
66
+ pts3d = F.relu(xyz)
67
+ elif normalize_act == "inv_log":
68
+ pts3d = inverse_log_transform(xyz)
69
+ elif normalize_act == "xy_inv_log":
70
+ xy, z = xyz.split([2, 1], dim=-1)
71
+ z = inverse_log_transform(z)
72
+ pts3d = torch.cat([xy * z, z], dim=-1)
73
+ elif normalize_act == "sigmoid":
74
+ pts3d = torch.sigmoid(xyz)
75
+ elif normalize_act == "linear":
76
+ pts3d = xyz
77
+ else:
78
+ raise ValueError(f"Unknown normalize_act: {normalize_act}")
79
+
80
+ # reg_dense_conf for mode='exp', with vmin=1, vmax=inf
81
+ # => conf_out = 1 + e^(conf)
82
+ # (since clip(max=vmax - vmin) with vmax=inf basically doesn’t limit anything)
83
+ if normalize_act_conf == "expp1":
84
+ conf_out = 1 + conf.exp()
85
+ elif normalize_act_conf == "expp0":
86
+ conf_out = conf.exp()
87
+ elif normalize_act_conf == "sigmoid":
88
+ conf_out = torch.sigmoid(conf)
89
+ else:
90
+ raise ValueError(f"Unknown normalize_act_conf: {normalize_act_conf}")
91
+
92
+ # Final dictionary
93
+ return pts3d, conf_out
94
+
95
+
96
+ def inverse_log_transform(y):
97
+ return torch.sign(y) * (torch.expm1(torch.abs(y)))
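A minimal sketch of calling activate_head from this file (shapes are illustrative): the last channel of the raw map is treated as confidence and the remaining channels as the 3D point prediction.

import torch
from vggt.heads.head_act import activate_head   # assumed module path for the file above

out = torch.randn(2, 4, 16, 16)                  # B, C = 3 point channels + 1 confidence, H, W
pts3d, conf = activate_head(out, normalize_act="norm_exp", normalize_act_conf="expp1")
print(pts3d.shape)                               # torch.Size([2, 16, 16, 3])
print(conf.shape)                                # torch.Size([2, 16, 16]); values > 1 since conf = 1 + exp(raw)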
vggt/heads/track_head.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # track head implementation (adapted from DUST3R head code)
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import random
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from .head_act import activate_head
14
+ from .utils import normalized_view_plane_uv, HarmonicEmbedding, position_grid_to_embed
15
+ from .dpt_head import DPTHead
16
+ from .match_head import MatchHead
17
+ from ..track_modules.base_track_predictor import BaseTrackerPredictor
18
+ from ..track_modules.base_track_predictor_v2 import BaseTrackerPredictorV2
19
+
20
+ EPS = 1e-6
21
+
22
+ def reduce_masked_mean(x, mask, dim=None, keepdim=False):
23
+ # x and mask have the same shape (or are broadcastable to it)
24
+ # returns the mean of x over the entries where mask is nonzero
25
+ # dim can be a single axis or a list of axes
26
+ for a, b in zip(x.size(), mask.size()):
27
+ assert a == b # some shape mismatch!
28
+ prod = x * mask
29
+ if dim is None:
30
+ numer = torch.sum(prod)
31
+ denom = EPS + torch.sum(mask)
32
+ else:
33
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
34
+ denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
35
+
36
+ mean = numer / denom
37
+ return mean
38
+
39
+ def balanced_ce_loss(pred, gt, valid=None):
40
+ """Balanced cross entropy loss.
41
+ pred: predicted scores
42
+ gt: binary ground truth
43
+ valid: validity mask
44
+ """
45
+ # pred and gt are the same shape
46
+ for a, b in zip(pred.size(), gt.size()):
47
+ assert a == b # some shape mismatch!
48
+ if valid is not None:
49
+ for a, b in zip(pred.size(), valid.size()):
50
+ assert a == b # some shape mismatch!
51
+ else:
52
+ valid = torch.ones_like(gt)
53
+
54
+ pos = (gt > 0.95).float()
55
+ neg = (gt < 0.05).float()
56
+
57
+ label = pos * 2.0 - 1.0
58
+ a = -label * pred
59
+ b = F.relu(a)
60
+ loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
61
+
62
+ pos_loss = reduce_masked_mean(loss, pos * valid)
63
+ neg_loss = reduce_masked_mean(loss, neg * valid)
64
+
65
+ balanced_loss = pos_loss + neg_loss
66
+
67
+ return balanced_loss, loss
68
+
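The three lines computing `loss` in balanced_ce_loss above are a numerically stable form of softplus(-label * pred), i.e. the standard logistic loss; a quick illustrative check:

import torch
import torch.nn.functional as F

pred = torch.randn(5)
label = torch.tensor([1.0, -1.0, 1.0, -1.0, 1.0])
a = -label * pred
b = F.relu(a)                                             # b = max(a, 0)
stable = b + torch.log(torch.exp(-b) + torch.exp(a - b))
assert torch.allclose(stable, F.softplus(a), atol=1e-6)   # log(1 + exp(a)) without overflow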
69
+ def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, vis_aware=False, huber=False, delta=10, vis_aware_w=0.1, **kwargs):
70
+ """Loss function defined over sequence of flow predictions"""
71
+ B, S, N, D = flow_gt.shape
72
+ assert D == 2
73
+ B, S1, N = vis.shape
74
+ B, S2, N = valids.shape
75
+ assert S == S1
76
+ assert S == S2
77
+ n_predictions = len(flow_preds)
78
+ flow_loss = 0.0
79
+
80
+ for i in range(n_predictions):
81
+ i_weight = gamma ** (n_predictions - i - 1)
82
+ flow_pred = flow_preds[i]
83
+
84
+ i_loss = (flow_pred - flow_gt).abs() # B, S, N, 2
85
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
86
+
87
+ # Combine valids and vis for per-frame valid masking.
88
+ combined_mask = torch.logical_and(valids, vis)
89
+
90
+ # valids * vis.float() # B, S, N
91
+
92
+ # vis_aware weighting. Apply BEFORE reduce_masked_mean
93
+
94
+ if vis_aware:
95
+ combined_mask = combined_mask.float() * (1.0 + vis_aware_w) # scale the mask weights by (1 + vis_aware_w) rather than adding to the mask entries
96
+ # combined_mask = torch.clamp(combined_mask, 0.0, 1.0) # No need to clamp.
97
+ # Apply the mask *before* taking the mean.
98
+ # i_loss = i_loss * combined_mask
99
+ # flow_loss += i_weight * i_loss.mean()
100
+ flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask)
101
+ else:
102
+ if combined_mask.numel() > 10:
103
+ # flow_loss += i_weight * i_loss.mean()
104
+ i_loss = i_loss[combined_mask]
105
+ flow_loss += i_weight * i_loss.mean()
106
+ else:
107
+ flow_loss += 0
108
+
109
+ # # Handle the case where no points are valid.
110
+ # if combined_mask.sum() > 0:
111
+ # flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask) # Pass combined_mask
112
+ # else: No valid points, so this term contributes 0 to the loss.
113
+ # flow_loss += 0. (This is implicit)
114
+
115
+ # Avoid division by zero if n_predictions is 0 (though it shouldn't be).
116
+ if n_predictions > 0:
117
+ flow_loss = flow_loss / n_predictions
118
+
119
+ return flow_loss
120
+
121
+ class TrackHead(nn.Module):
122
+ """
123
+ Track head that uses DPT/Match head to process tokens and BaseTrackerPredictor for tracking.
124
+ """
125
+ def __init__(self,
126
+ dim_in,
127
+ patch_size=16,
128
+ features=128,
129
+ feature_extractor_type="dpt", # or "match"
130
+ train_query_points=128,
131
+ feature_extractor_kwargs={},
132
+ tracker_kwargs={},
133
+ loss_kwargs={},
134
+ iters=4,
135
+ use_base_tracker_v2=False,
136
+ predict_conf=False,
137
+ random_query_points = None,
138
+ **kwargs):
139
+ super().__init__()
140
+
141
+ self.patch_size = patch_size
142
+ self.feature_extractor_type = feature_extractor_type
143
+ self.train_query_points = train_query_points
144
+ self.random_query_points = random_query_points
145
+
146
+ # Initialize feature extractor (DPT or Match head)
147
+ if feature_extractor_type == "dpt":
148
+ self.feature_extractor = DPTHead(
149
+ dim_in=dim_in,
150
+ patch_size=patch_size,
151
+ features=features,
152
+ feature_only=True, # Only output features, no activation
153
+ **feature_extractor_kwargs
154
+ )
155
+ elif feature_extractor_type == "match":
156
+ raise NotImplementedError("Match head is not implemented for track head")
157
+ self.feature_extractor = MatchHead(
158
+ dim_in=dim_in,
159
+ patch_size=patch_size,
160
+ features=features,
161
+ **feature_extractor_kwargs
162
+ )
163
+ else:
164
+ raise ValueError(f"Unknown feature_extractor_type: {feature_extractor_type}")
165
+
166
+ # Initialize tracker
167
+ if use_base_tracker_v2:
168
+ self.tracker = BaseTrackerPredictorV2(
169
+ latent_dim=features, # Match the output_dim of feature extractor
170
+ predict_conf=predict_conf,
171
+ **tracker_kwargs
172
+ )
173
+ else:
174
+ self.tracker = BaseTrackerPredictor(
175
+ latent_dim=features, # Match the output_dim of feature extractor
176
+ predict_conf=predict_conf,
177
+ **tracker_kwargs
178
+ )
179
+
180
+ self.loss_kwargs = loss_kwargs
181
+ self.iters = iters
182
+
183
+
184
+ def _compute_losses(self, coord_preds, vis_scores, conf_scores, batch):
185
+ """Compute tracking losses using sequence_loss"""
186
+ gt_tracks = batch["tracks"] # B, S, N, 2
187
+ gt_track_vis_mask = batch["track_vis_mask"] # B, S, N
188
+
189
+ # if self.training and hasattr(self, "train_query_points"):
190
+ train_query_points = coord_preds[-1].shape[2]
191
+ gt_tracks = gt_tracks[:, :, :train_query_points]
192
+ gt_track_vis_mask = gt_track_vis_mask[:, :, :train_query_points]
193
+
194
+ # Create validity mask that filters out tracks not visible in first frame
195
+ valids = torch.ones_like(gt_track_vis_mask)
196
+ mask = gt_track_vis_mask[:, 0, :] == True
197
+ valids = valids * mask.unsqueeze(1)
198
+
199
+ # Compute tracking loss using sequence_loss
200
+ track_loss = sequence_loss(
201
+ flow_preds=coord_preds,
202
+ flow_gt=gt_tracks,
203
+ vis=gt_track_vis_mask,
204
+ valids=valids,
205
+ **self.loss_kwargs
206
+ )
207
+
208
+ vis_loss = F.binary_cross_entropy_with_logits(vis_scores[valids], gt_track_vis_mask[valids].float())
209
+ # within 3 pixels
210
+ if conf_scores is not None:
211
+ gt_conf_mask = (gt_tracks - coord_preds[-1]).norm(dim=-1) < 3
212
+ conf_loss = F.binary_cross_entropy_with_logits(conf_scores[valids], gt_conf_mask[valids].float())
213
+ else:
214
+ conf_loss = 0
215
+
216
+ return track_loss, vis_loss, conf_loss
217
+
218
+ def forward(self, aggregated_tokens_list, batch, patch_start_idx):
219
+ B, S, _, H, W = batch["images"].shape
220
+
221
+ gt_tracks = batch["tracks"] # B, S, N, 2
222
+ # gt_track_vis_mask = batch["track_vis_mask"] # B, S, N
223
+
224
+ # Extract features using DPT/Match head
225
+ if self.feature_extractor_type == "dpt":
226
+ feature_maps = self.feature_extractor(aggregated_tokens_list, batch, patch_start_idx)
227
+ else: # match head
228
+ feature_maps = self.feature_extractor(aggregated_tokens_list, batch, patch_start_idx)["descriptor"]
229
+
230
+ feature_maps = feature_maps.view(B, S, *feature_maps.shape[1:]).clone()
231
+ # Get query points from batch
232
+
233
+ query_points = gt_tracks[:, 0] # Use first frame's points as query
234
+
235
+ if self.training:
236
+ if self.random_query_points is not None:
237
+ min_val = self.random_query_points[0]
238
+ max_val = self.random_query_points[1]
239
+ mu = max_val # Mean centered at the upper bound
240
+ sigma = (max_val - min_val) / 2.71 # std. dev. chosen so min_val sits roughly e (~2.71) sigmas below the mean
241
+ train_query_points = int(random.gauss(mu, sigma))
242
+ train_query_points = max(min(train_query_points, max_val), min_val) # Clamp to ensure value is within range
243
+ else:
244
+ train_query_points = self.train_query_points
245
+ query_points = query_points[:, :train_query_points]
246
+
247
+ # Predict tracks using BaseTrackerPredictor
248
+ # coord_preds: a list of B, S, N, 2
249
+ # vis_scores: B, S, N
250
+ coord_preds, vis_scores, conf_scores = self.tracker(
251
+ query_points=query_points,
252
+ fmaps=feature_maps,
253
+ iters=self.iters,
254
+ )
255
+
256
+ # Calculate losses if in training mode
257
+ track_loss, vis_loss, conf_loss = self._compute_losses(coord_preds, vis_scores, conf_scores, batch)
258
+
259
+ loss_dict = {
260
+ "loss_track": track_loss,
261
+ "loss_vis": vis_loss,
262
+ "loss_track_conf": conf_loss,
263
+ "last_track_pred": coord_preds[-1],
264
+ }
265
+ return loss_dict
266
+
267
+
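In sequence_loss above, gamma exponentially discounts earlier iterations so the final refinement dominates; a small sketch of the per-iteration weights with the default gamma=0.8 and four predictions:

gamma, n_predictions = 0.8, 4
weights = [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]
print(weights)   # roughly [0.512, 0.64, 0.8, 1.0]; the last, most refined prediction gets full weight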
vggt/heads/utils.py ADDED
@@ -0,0 +1,309 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from typing import Optional
5
+
6
+
7
+
8
+
9
+ def make_sincos_pos_embed(
10
+ embed_dim: int, pos: torch.Tensor, omega_0: float = 100
11
+ ) -> torch.Tensor:
12
+ """
13
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
14
+
15
+ Args:
16
+ - embed_dim: The embedding dimension.
17
+ - pos: The position to generate the embedding from.
18
+
19
+ Returns:
20
+ - emb: The generated 1D positional embedding.
21
+ """
22
+ assert embed_dim % 2 == 0
23
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
24
+ omega /= embed_dim / 2.0
25
+ omega = 1.0 / omega_0**omega # (D/2,)
26
+
27
+ pos = pos.reshape(-1) # (M,)
28
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
29
+
30
+ emb_sin = torch.sin(out) # (M, D/2)
31
+ emb_cos = torch.cos(out) # (M, D/2)
32
+
33
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
34
+ return emb.float()
35
+
36
+
37
+
38
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
39
+ """
40
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
41
+
42
+ Args:
43
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
44
+ embed_dim: Output channel dimension for embeddings
45
+
46
+ Returns:
47
+ Tensor of shape (H, W, embed_dim) with positional embeddings
48
+ """
49
+ H, W, grid_dim = pos_grid.shape
50
+ assert grid_dim == 2
51
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
52
+
53
+ # Process x and y coordinates separately
54
+ emb_x = make_sincos_pos_embed(embed_dim//2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
55
+ emb_y = make_sincos_pos_embed(embed_dim//2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
56
+
57
+ # Combine and reshape
58
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
59
+
60
+ return emb.view(H, W, embed_dim) # [H, W, D]
61
+
62
+
63
+ class HarmonicEmbedding(torch.nn.Module):
64
+ def __init__(
65
+ self,
66
+ n_harmonic_functions: int = 6,
67
+ omega_0: float = 1.0,
68
+ logspace: bool = True,
69
+ append_input: bool = True,
70
+ ) -> None:
71
+ """
72
+ The harmonic embedding layer supports the classical
73
+ Nerf positional encoding described in
74
+ `NeRF <https://arxiv.org/abs/2003.08934>`_
75
+ and the integrated position encoding in
76
+ `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.
77
+
78
+ During the inference you can provide the extra argument `diag_cov`.
79
+
80
+ If `diag_cov is None`, it converts
81
+ rays parametrized with a `ray_bundle` to 3D points by
82
+ extending each ray according to the corresponding length.
83
+ Then it converts each feature
84
+ (i.e. vector along the last dimension) in `x`
85
+ into a series of harmonic features `embedding`,
86
+ where for each i in range(dim) the following are present
87
+ in embedding[...]::
88
+
89
+ [
90
+ sin(f_1*x[..., i]),
91
+ sin(f_2*x[..., i]),
92
+ ...
93
+ sin(f_N * x[..., i]),
94
+ cos(f_1*x[..., i]),
95
+ cos(f_2*x[..., i]),
96
+ ...
97
+ cos(f_N * x[..., i]),
98
+ x[..., i], # only present if append_input is True.
99
+ ]
100
+
101
+ where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
102
+ denoting the i-th frequency of the harmonic embedding.
103
+
104
+
105
+ If `diag_cov is not None`, it approximates
106
+ conical frustums following a ray bundle as gaussians,
107
+ defined by x, the means of the gaussians and diag_cov,
108
+ the diagonal covariances.
109
+ Then it converts each gaussian
110
+ into a series of harmonic features `embedding`,
111
+ where for each i in range(dim) the following are present
112
+ in embedding[...]::
113
+
114
+ [
115
+ sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
116
+ sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),
117
+ ...
118
+ sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
119
+ cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
120
+ cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),
121
+ ...
122
+ cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
123
+ x[..., i], # only present if append_input is True.
124
+ ]
125
+
126
+ where N equals `n_harmonic_functions-1`, and f_i is a scalar
127
+ denoting the i-th frequency of the harmonic embedding.
128
+
129
+ If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
130
+ powers of 2:
131
+ `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`
132
+
133
+ If `logspace==False`, frequencies are linearly spaced between
134
+ `1.0` and `2**(n_harmonic_functions-1)`:
135
+ `f_1, ..., f_N = torch.linspace(
136
+ 1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
137
+ )`
138
+
139
+ Note that `x` is also premultiplied by the base frequency `omega_0`
140
+ before evaluating the harmonic functions.
141
+
142
+ Args:
143
+ n_harmonic_functions: int, number of harmonic
144
+ features
145
+ omega_0: float, base frequency
146
+ logspace: bool, Whether to space the frequencies in
147
+ logspace or linear space
148
+ append_input: bool, whether to concat the original
149
+ input to the harmonic embedding. If true the
150
+ output is of the form (embed.sin(), embed.cos(), x)
151
+ """
152
+ super().__init__()
153
+
154
+ if logspace:
155
+ frequencies = 2.0 ** torch.arange(
156
+ n_harmonic_functions, dtype=torch.float32
157
+ )
158
+ else:
159
+ frequencies = torch.linspace(
160
+ 1.0,
161
+ 2.0 ** (n_harmonic_functions - 1),
162
+ n_harmonic_functions,
163
+ dtype=torch.float32,
164
+ )
165
+
166
+ self.register_buffer(
167
+ "_frequencies", frequencies * omega_0, persistent=False
168
+ )
169
+ self.register_buffer(
170
+ "_zero_half_pi",
171
+ torch.tensor([0.0, 0.5 * torch.pi]),
172
+ persistent=False,
173
+ )
174
+ self.append_input = append_input
175
+
176
+ def forward(
177
+ self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
178
+ ) -> torch.Tensor:
179
+ """
180
+ Args:
181
+ x: tensor of shape [..., dim]
182
+ diag_cov: An optional tensor of shape `(..., dim)`
183
+ representing the diagonal covariance matrices of our Gaussians, joined with x
184
+ as means of the Gaussians.
185
+
186
+ Returns:
187
+ embedding: a harmonic embedding of `x` of shape
188
+ [..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray]
189
+ """
190
+ # [..., dim, n_harmonic_functions]
191
+ embed = x[..., None] * self._frequencies
192
+ # [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
193
+ embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
194
+ # Use the trig identity cos(x) = sin(x + pi/2)
195
+ # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
196
+ embed = embed.sin()
197
+ if diag_cov is not None:
198
+ x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
199
+ exp_var = torch.exp(-0.5 * x_var)
200
+ # [..., 2, dim, n_harmonic_functions]
201
+ embed = embed * exp_var[..., None, :, :]
202
+
203
+ embed = embed.reshape(*x.shape[:-1], -1)
204
+
205
+ if self.append_input:
206
+ return torch.cat([embed, x], dim=-1)
207
+ return embed
208
+
209
+ @staticmethod
210
+ def get_output_dim_static(
211
+ input_dims: int, n_harmonic_functions: int, append_input: bool
212
+ ) -> int:
213
+ """
214
+ Utility to help predict the shape of the output of `forward`.
215
+
216
+ Args:
217
+ input_dims: length of the last dimension of the input tensor
218
+ n_harmonic_functions: number of embedding frequencies
219
+ append_input: whether or not to concat the original
220
+ input to the harmonic embedding
221
+ Returns:
222
+ int: the length of the last dimension of the output tensor
223
+ """
224
+ return input_dims * (2 * n_harmonic_functions + int(append_input))
225
+
226
+ def get_output_dim(self, input_dims: int = 3) -> int:
227
+ """
228
+ Same as above. The default for input_dims is 3 for 3D applications
229
+ which use harmonic embedding for positional encoding,
230
+ so the input might be xyz.
231
+ """
232
+ return self.get_output_dim_static(
233
+ input_dims, len(self._frequencies), self.append_input
234
+ )
235
+
236
+
237
+
238
+
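A minimal sketch of the HarmonicEmbedding output size (input sizes are illustrative, and the import path is an assumption): with append_input=True, each of the D input channels yields 2 * n_harmonic_functions sinusoids plus the raw value.

import torch
from vggt.heads.utils import HarmonicEmbedding   # assumed module path for the class above

emb = HarmonicEmbedding(n_harmonic_functions=6, append_input=True)
xyz = torch.randn(2, 100, 3)
out = emb(xyz)
assert out.shape == (2, 100, emb.get_output_dim(3))   # 3 * (2 * 6 + 1) = 39 channels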
239
+ class PoseEmbedding(nn.Module):
240
+ def __init__(self, target_dim, n_harmonic_functions=10, append_input=True):
241
+ super().__init__()
242
+
243
+ self._emb_pose = HarmonicEmbedding(
244
+ n_harmonic_functions=n_harmonic_functions, append_input=append_input
245
+ )
246
+
247
+ self.out_dim = self._emb_pose.get_output_dim(target_dim)
248
+
249
+ def forward(self, pose_encoding):
250
+ e_pose_encoding = self._emb_pose(pose_encoding)
251
+ return e_pose_encoding
252
+
253
+
254
+
255
+
256
+ def random_mask_single_patch_vectorized(images, patch_size=(16, 16)):
257
+ """
258
+ Randomly masks a single patch in a batch of images using fully vectorized operations.
259
+ :param images: Tensor of shape [B, 3, H, W]
260
+ :param patch_size: Tuple (ph, pw), size of the patch to mask
261
+ """
262
+ B, C, H, W = images.shape
263
+ ph, pw = patch_size
264
+
265
+ # Generate random positions for the top-left corner of the patch
266
+ x_positions = torch.randint(0, W - pw, (B, 1, 1))
267
+ y_positions = torch.randint(0, H - ph, (B, 1, 1))
268
+
269
+ # Compute patch grid indices
270
+ patch_x = torch.arange(pw).reshape(1, 1, pw)
271
+ patch_y = torch.arange(ph).reshape(1, ph, 1)
272
+
273
+ # Broadcast patch indices to each position
274
+ x_indices = x_positions + patch_x
275
+ y_indices = y_positions + patch_y
276
+
277
+ # Expand the indices to cover all channels and all images in the batch
278
+ x_indices = x_indices.expand(B, ph, pw)
279
+ y_indices = y_indices.expand(B, ph, pw)
280
+
281
+ # Flatten the indices to apply the mask using advanced indexing
282
+ batch_indices = torch.arange(B).unsqueeze(-1).expand(B, ph * pw)
283
+ x_indices = x_indices.reshape(B, ph * pw)
284
+ y_indices = y_indices.reshape(B, ph * pw)
285
+
286
+ # Create a mask initialized to one and apply zero at the indices
287
+ mask = torch.ones_like(images)
288
+ mask[batch_indices, :, y_indices, x_indices] = 0
289
+
290
+ # Apply mask to images
291
+ return images * mask
292
+
293
+
294
+
295
+
296
+ def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
297
+ # borrowed from https://github.com/microsoft/moge
298
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
299
+ if aspect_ratio is None:
300
+ aspect_ratio = width / height
301
+
302
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
303
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
304
+
305
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
306
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
307
+ u, v = torch.meshgrid(u, v, indexing='xy')
308
+ uv = torch.stack([u, v], dim=-1)
309
+ return uv
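A minimal sketch of the positional-embedding path used by the DPT head's _apply_pos_embed (sizes are illustrative): the normalized view-plane UV grid is turned into sinusoidal channel embeddings and added to the feature map with a small ratio.

import torch
from vggt.heads.utils import normalized_view_plane_uv, position_grid_to_embed   # assumed module path

W, H, C = 32, 24, 128
uv = normalized_view_plane_uv(W, H, aspect_ratio=W / H)   # (H, W, 2) view-plane UV grid
pe = position_grid_to_embed(uv, C)                        # (H, W, C) sinusoidal embedding
feat = torch.randn(1, C, H, W)
feat = feat + 0.1 * pe.permute(2, 0, 1)[None]             # ratio=0.1, matching the head's default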
vggt/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
vggt/layers/attention.py ADDED
@@ -0,0 +1,116 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import logging
11
+ import os
12
+ import warnings
13
+
14
+ from torch import Tensor
15
+ from torch import nn
16
+ import torch.nn.functional as F
17
+
18
+
19
+ logger = logging.getLogger("dinov2")
20
+
21
+
22
+ # XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
23
+ # try:
24
+ # if XFORMERS_ENABLED:
25
+ # from xformers.ops import memory_efficient_attention, unbind
26
+
27
+ # XFORMERS_AVAILABLE = True
28
+ # warnings.warn("xFormers is available (Attention)")
29
+ # else:
30
+ # warnings.warn("xFormers is disabled (Attention)")
31
+ # raise ImportError
32
+ # except ImportError:
33
+ # XFORMERS_AVAILABLE = False
34
+ # warnings.warn("xFormers is not available (Attention)")
35
+
36
+ XFORMERS_AVAILABLE = False
37
+
38
+
39
+ class Attention(nn.Module):
40
+ def __init__(
41
+ self,
42
+ dim: int,
43
+ num_heads: int = 8,
44
+ qkv_bias: bool = True,
45
+ qk_norm: bool = False,
46
+ attn_drop: float = 0.,
47
+ proj_drop: float = 0.,
48
+ proj_bias: bool = True,
49
+ norm_layer: nn.Module = nn.LayerNorm,
50
+ fused_attn: bool = True,
51
+ rope = None,
52
+ ) -> None:
53
+ super().__init__()
54
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
55
+ self.num_heads = num_heads
56
+ self.head_dim = dim // num_heads
57
+ self.scale = self.head_dim ** -0.5
58
+ self.fused_attn = fused_attn
59
+
60
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
61
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
62
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
63
+ self.attn_drop = nn.Dropout(attn_drop)
64
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
65
+ self.proj_drop = nn.Dropout(proj_drop)
66
+ self.rope = rope
67
+
68
+ def forward(self, x: Tensor, pos=None) -> Tensor:
69
+ B, N, C = x.shape
70
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
71
+ q, k, v = qkv.unbind(0)
72
+ q, k = self.q_norm(q), self.k_norm(k)
73
+
74
+ if self.rope is not None:
75
+ q = self.rope(q, pos)
76
+ k = self.rope(k, pos)
77
+
78
+ if self.fused_attn:
79
+ x = F.scaled_dot_product_attention(
80
+ q, k, v,
81
+ dropout_p=self.attn_drop.p if self.training else 0.,
82
+ )
83
+ else:
84
+ q = q * self.scale
85
+ attn = q @ k.transpose(-2, -1)
86
+ attn = attn.softmax(dim=-1)
87
+ attn = self.attn_drop(attn)
88
+ x = attn @ v
89
+
90
+ x = x.transpose(1, 2).reshape(B, N, C)
91
+ x = self.proj(x)
92
+ x = self.proj_drop(x)
93
+ return x
94
+
95
+
96
+
97
+
98
+ class MemEffAttention(Attention):
99
+ def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
100
+ assert pos is None
101
+ if not XFORMERS_AVAILABLE:
102
+ if attn_bias is not None:
103
+ raise AssertionError("xFormers is required for using nested tensors")
104
+ return super().forward(x)
105
+
106
+ B, N, C = x.shape
107
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
108
+
109
+ q, k, v = unbind(qkv, 2)
110
+
111
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
112
+ x = x.reshape([B, N, C])
113
+
114
+ x = self.proj(x)
115
+ x = self.proj_drop(x)
116
+ return x
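A minimal sketch of running the Attention module above on a dummy token sequence (dimensions are assumptions); with fused_attn=True the forward pass dispatches to F.scaled_dot_product_attention.

import torch
from vggt.layers.attention import Attention   # assumed package layout

attn = Attention(dim=384, num_heads=6)
tokens = torch.randn(2, 197, 384)              # B, N, C (e.g. 196 patch tokens + 1 cls token)
out = attn(tokens)
assert out.shape == tokens.shape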
vggt/layers/block.py ADDED
@@ -0,0 +1,275 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention, MemEffAttention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ logger = logging.getLogger("dinov2")
25
+
26
+
27
+ # XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
28
+ # try:
29
+ # if XFORMERS_ENABLED:
30
+ # from xformers.ops import fmha, scaled_index_add, index_select_cat
31
+
32
+ # XFORMERS_AVAILABLE = True
33
+ # warnings.warn("xFormers is available (Block)")
34
+ # else:
35
+ # warnings.warn("xFormers is disabled (Block)")
36
+ # raise ImportError
37
+ # except ImportError:
38
+ # XFORMERS_AVAILABLE = False
39
+
40
+ # warnings.warn("xFormers is not available (Block)")
41
+
42
+ XFORMERS_AVAILABLE = False
43
+
44
+ class Block(nn.Module):
45
+ def __init__(
46
+ self,
47
+ dim: int,
48
+ num_heads: int,
49
+ mlp_ratio: float = 4.0,
50
+ qkv_bias: bool = True,
51
+ qk_norm: bool = False,
52
+ proj_bias: bool = True,
53
+ ffn_bias: bool = True,
54
+ fused_attn: bool = True,
55
+ drop: float = 0.0,
56
+ attn_drop: float = 0.0,
57
+ init_values=None,
58
+ drop_path: float = 0.0,
59
+ act_layer: Callable[..., nn.Module] = nn.GELU,
60
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
61
+ attn_class: Callable[..., nn.Module] = Attention,
62
+ ffn_layer: Callable[..., nn.Module] = Mlp,
63
+ rope_freq: int = -1,
64
+ rope = None,
65
+ ) -> None:
66
+ super().__init__()
67
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
68
+ self.norm1 = norm_layer(dim)
69
+
70
+ self.attn = attn_class(
71
+ dim,
72
+ num_heads=num_heads,
73
+ qkv_bias=qkv_bias,
74
+ qk_norm=qk_norm,
75
+ proj_bias=proj_bias,
76
+ attn_drop=attn_drop,
77
+ proj_drop=drop,
78
+ fused_attn=fused_attn,
79
+ rope=rope,
80
+ )
81
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
82
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
83
+
84
+ self.norm2 = norm_layer(dim)
85
+ mlp_hidden_dim = int(dim * mlp_ratio)
86
+ self.mlp = ffn_layer(
87
+ in_features=dim,
88
+ hidden_features=mlp_hidden_dim,
89
+ act_layer=act_layer,
90
+ drop=drop,
91
+ bias=ffn_bias,
92
+ )
93
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
94
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
95
+
96
+ self.sample_drop_ratio = drop_path
97
+
98
+ def forward(self, x: Tensor, pos=None) -> Tensor:
99
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
100
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
101
+
102
+ def ffn_residual_func(x: Tensor) -> Tensor:
103
+ return self.ls2(self.mlp(self.norm2(x)))
104
+
105
+ if self.training and self.sample_drop_ratio > 0.1:
106
+ # the overhead is compensated only for a drop path rate larger than 0.1
107
+ x = drop_add_residual_stochastic_depth(
108
+ x,
109
+ pos=pos,
110
+ residual_func=attn_residual_func,
111
+ sample_drop_ratio=self.sample_drop_ratio,
112
+ )
113
+ x = drop_add_residual_stochastic_depth(
114
+ x,
115
+ residual_func=ffn_residual_func,
116
+ sample_drop_ratio=self.sample_drop_ratio,
117
+ )
118
+ elif self.training and self.sample_drop_ratio > 0.0:
119
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
120
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
121
+ else:
122
+ x = x + attn_residual_func(x, pos=pos)
123
+ x = x + ffn_residual_func(x)
124
+ return x
125
+
126
+
127
+ def drop_add_residual_stochastic_depth(
128
+ x: Tensor,
129
+ residual_func: Callable[[Tensor], Tensor],
130
+ sample_drop_ratio: float = 0.0,
131
+ pos = None,
132
+ ) -> Tensor:
133
+ # 1) extract subset using permutation
134
+ b, n, d = x.shape
135
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
136
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
137
+ x_subset = x[brange]
138
+
139
+ # 2) apply residual_func to get residual
140
+ if pos is not None:
141
+ pos = pos[brange]
142
+ residual = residual_func(x_subset, pos=pos)
143
+ else:
144
+ residual = residual_func(x_subset)
145
+
146
+ x_flat = x.flatten(1)
147
+ residual = residual.flatten(1)
148
+
149
+ residual_scale_factor = b / sample_subset_size
150
+
151
+ # 3) add the residual
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ return x_plus_residual.view_as(x)
154
+
155
+
156
+ def get_branges_scales(x, sample_drop_ratio=0.0):
157
+ b, n, d = x.shape
158
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
159
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
160
+ residual_scale_factor = b / sample_subset_size
161
+ return brange, residual_scale_factor
162
+
163
+
164
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
165
+ if scaling_vector is None:
166
+ x_flat = x.flatten(1)
167
+ residual = residual.flatten(1)
168
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
169
+ else:
170
+ x_plus_residual = scaled_index_add(
171
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
172
+ )
173
+ return x_plus_residual
174
+
175
+
176
+ attn_bias_cache: Dict[Tuple, Any] = {}
177
+
178
+
179
+ def get_attn_bias_and_cat(x_list, branges=None):
180
+ """
181
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
182
+ """
183
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
184
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
185
+ if all_shapes not in attn_bias_cache.keys():
186
+ seqlens = []
187
+ for b, x in zip(batch_sizes, x_list):
188
+ for _ in range(b):
189
+ seqlens.append(x.shape[1])
190
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
191
+ attn_bias._batch_sizes = batch_sizes
192
+ attn_bias_cache[all_shapes] = attn_bias
193
+
194
+ if branges is not None:
195
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
196
+ else:
197
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
198
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
199
+
200
+ return attn_bias_cache[all_shapes], cat_tensors
201
+
202
+
203
+ def drop_add_residual_stochastic_depth_list(
204
+ x_list: List[Tensor],
205
+ residual_func: Callable[[Tensor, Any], Tensor],
206
+ sample_drop_ratio: float = 0.0,
207
+ scaling_vector=None,
208
+ ) -> Tensor:
209
+ # 1) generate random set of indices for dropping samples in the batch
210
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
211
+ branges = [s[0] for s in branges_scales]
212
+ residual_scale_factors = [s[1] for s in branges_scales]
213
+
214
+ # 2) get attention bias and index+concat the tensors
215
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
216
+
217
+ # 3) apply residual_func to get residual, and split the result
218
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
219
+
220
+ outputs = []
221
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
222
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
223
+ return outputs
224
+
225
+
226
+ class NestedTensorBlock(Block):
227
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
228
+ """
229
+ x_list contains a list of tensors to nest together and run
230
+ """
231
+ assert isinstance(self.attn, MemEffAttention)
232
+
233
+ if self.training and self.sample_drop_ratio > 0.0:
234
+
235
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
236
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
237
+
238
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
239
+ return self.mlp(self.norm2(x))
240
+
241
+ x_list = drop_add_residual_stochastic_depth_list(
242
+ x_list,
243
+ residual_func=attn_residual_func,
244
+ sample_drop_ratio=self.sample_drop_ratio,
245
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
246
+ )
247
+ x_list = drop_add_residual_stochastic_depth_list(
248
+ x_list,
249
+ residual_func=ffn_residual_func,
250
+ sample_drop_ratio=self.sample_drop_ratio,
251
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
252
+ )
253
+ return x_list
254
+ else:
255
+
256
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
257
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
258
+
259
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
260
+ return self.ls2(self.mlp(self.norm2(x)))
261
+
262
+ attn_bias, x = get_attn_bias_and_cat(x_list)
263
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
264
+ x = x + ffn_residual_func(x)
265
+ return attn_bias.split(x)
266
+
267
+ def forward(self, x_or_x_list):
268
+ if isinstance(x_or_x_list, Tensor):
269
+ return super().forward(x_or_x_list)
270
+ elif isinstance(x_or_x_list, list):
271
+ if not XFORMERS_AVAILABLE:
272
+ raise AssertionError("xFormers is required for using nested tensors")
273
+ return self.forward_nested(x_or_x_list)
274
+ else:
275
+ raise AssertionError
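A minimal sketch of the plain (non-nested) Block forward pass (dimensions are assumptions): in eval mode the stochastic-depth branches are skipped and the attention and MLP residuals are added directly.

import torch
from vggt.layers.block import Block   # assumed package layout

block = Block(dim=384, num_heads=6, mlp_ratio=4.0).eval()
x = torch.randn(2, 196, 384)           # B, N, C
with torch.no_grad():
    y = block(x)                       # x + attention residual, then + MLP residual
assert y.shape == x.shape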
vggt/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
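A minimal sketch of the DINOHead shapes (sizes are illustrative): the MLP projects to the bottleneck, features are L2-normalized, and the weight-normalized last layer maps to the prototype dimension.

import torch
from vggt.layers.dino_head import DINOHead   # assumed package layout

head = DINOHead(in_dim=384, out_dim=65536)
feats = torch.randn(4, 384)
logits = head(feats)
assert logits.shape == (4, 65536)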
vggt/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
9
+
10
+
11
+ from torch import nn
12
+
13
+
14
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
15
+ if drop_prob == 0.0 or not training:
16
+ return x
17
+ keep_prob = 1 - drop_prob
18
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
19
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
20
+ if keep_prob > 0.0:
21
+ random_tensor.div_(keep_prob)
22
+ output = x * random_tensor
23
+ return output
24
+
25
+
26
+ class DropPath(nn.Module):
27
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
28
+
29
+ def __init__(self, drop_prob=None):
30
+ super(DropPath, self).__init__()
31
+ self.drop_prob = drop_prob
32
+
33
+ def forward(self, x):
34
+ return drop_path(x, self.drop_prob, self.training)
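drop_path above zeroes entire samples in a residual branch and rescales the survivors by 1/keep_prob so the expected value is preserved; a quick illustrative check:

import torch
from vggt.layers.drop_path import drop_path   # assumed package layout

x = torch.ones(10000, 8)
y = drop_path(x, drop_prob=0.2, training=True)
print(y.mean().item())   # close to 1.0: dropped rows are zero, kept rows are scaled by 1 / 0.8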
vggt/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
7
+
8
+ from typing import Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ from torch import nn
13
+
14
+
15
+ class LayerScale(nn.Module):
16
+ def __init__(
17
+ self,
18
+ dim: int,
19
+ init_values: Union[float, Tensor] = 1e-5,
20
+ inplace: bool = False,
21
+ ) -> None:
22
+ super().__init__()
23
+ self.inplace = inplace
24
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
25
+
26
+ def forward(self, x: Tensor) -> Tensor:
27
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
vggt/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
vggt/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
+ def make_2tuple(x):
17
+ if isinstance(x, tuple):
18
+ assert len(x) == 2
19
+ return x
20
+
21
+ assert isinstance(x, int)
22
+ return (x, x)
23
+
24
+
25
+ class PatchEmbed(nn.Module):
26
+ """
27
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
28
+
29
+ Args:
30
+ img_size: Image size.
31
+ patch_size: Patch token size.
32
+ in_chans: Number of input image channels.
33
+ embed_dim: Number of linear projection output channels.
34
+ norm_layer: Normalization layer.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ img_size: Union[int, Tuple[int, int]] = 224,
40
+ patch_size: Union[int, Tuple[int, int]] = 16,
41
+ in_chans: int = 3,
42
+ embed_dim: int = 768,
43
+ norm_layer: Optional[Callable] = None,
44
+ flatten_embedding: bool = True,
45
+ ) -> None:
46
+ super().__init__()
47
+
48
+ image_HW = make_2tuple(img_size)
49
+ patch_HW = make_2tuple(patch_size)
50
+ patch_grid_size = (
51
+ image_HW[0] // patch_HW[0],
52
+ image_HW[1] // patch_HW[1],
53
+ )
54
+
55
+ self.img_size = image_HW
56
+ self.patch_size = patch_HW
57
+ self.patches_resolution = patch_grid_size
58
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
59
+
60
+ self.in_chans = in_chans
61
+ self.embed_dim = embed_dim
62
+
63
+ self.flatten_embedding = flatten_embedding
64
+
65
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
66
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
67
+
68
+ def forward(self, x: Tensor) -> Tensor:
69
+ _, _, H, W = x.shape
70
+ patch_H, patch_W = self.patch_size
71
+
72
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
73
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width {patch_W}"
74
+
75
+ x = self.proj(x) # B C H W
76
+ H, W = x.size(2), x.size(3)
77
+ x = x.flatten(2).transpose(1, 2) # B HW C
78
+ x = self.norm(x)
79
+ if not self.flatten_embedding:
80
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
81
+ return x
82
+
83
+ def flops(self) -> float:
84
+ Ho, Wo = self.patches_resolution
85
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
86
+ if self.norm is not None:
87
+ flops += Ho * Wo * self.embed_dim
88
+ return flops
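A short sanity check for PatchEmbed above, matching the (B,C,H,W) -> (B,N,D) contract in its docstring (import path assumed from the file location):

import torch
from vggt.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)
tokens = embed(x)                      # flatten_embedding=True -> (B, N, D)
assert tokens.shape == (2, 14 * 14, 768)
assert embed.num_patches == 196        # (224 // 16) ** 2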
vggt/layers/rope.py ADDED
@@ -0,0 +1,160 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+
5
+ class PositionGetter(object):
6
+ """ return positions of patches """
7
+
8
+ # NOTE: the cache can grow large when the token grid size (h, w) varies across inputs
9
+
10
+ def __init__(self):
11
+ self.cache_positions = {}
12
+
13
+ def __call__(self, b, h, w, device):
14
+ if not (h,w) in self.cache_positions:
15
+ x = torch.arange(w, device=device)
16
+ y = torch.arange(h, device=device)
17
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h*w, 2)
18
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
19
+ return pos
20
+
21
+
22
+ # --------------------------------------------------------
23
+ # 2D sine-cosine position embedding
24
+ # References:
25
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
26
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
27
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
28
+ # --------------------------------------------------------
29
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
30
+ """
31
+ grid_size: tuple (height, width) of the grid
32
+ return:
33
+ pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [n_cls_token+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
34
+ """
35
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
36
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
37
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
38
+ grid = np.stack(grid, axis=0)
39
+
40
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
41
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
42
+ if n_cls_token>0:
43
+ pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
44
+ return pos_embed
45
+
46
+
47
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
48
+ assert embed_dim % 2 == 0
49
+
50
+ # use half of dimensions to encode grid_h
51
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
52
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
53
+
54
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
55
+ return emb
56
+
57
+
58
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
59
+ """
60
+ embed_dim: output dimension for each position
61
+ pos: a list of positions to be encoded: size (M,)
62
+ out: (M, D)
63
+ """
64
+ assert embed_dim % 2 == 0
65
+ omega = np.arange(embed_dim // 2, dtype=float)
66
+ omega /= embed_dim / 2.
67
+ omega = 1. / 10000**omega # (D/2,)
68
+
69
+ pos = pos.reshape(-1) # (M,)
70
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
71
+
72
+ emb_sin = np.sin(out) # (M, D/2)
73
+ emb_cos = np.cos(out) # (M, D/2)
74
+
75
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
76
+ return emb
77
+
78
+
79
+ # --------------------------------------------------------
80
+ # Interpolate position embeddings for high-resolution
81
+ # References:
82
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
83
+ # DeiT: https://github.com/facebookresearch/deit
84
+ # --------------------------------------------------------
85
+ def interpolate_pos_embed(model, checkpoint_model):
86
+ keys = ['enc_pos_embed']+(['dec_pos_embed'] if hasattr(model,'dec_blocks') else [])
87
+ img_size = model.patch_embed.img_size
88
+ if isinstance(img_size,int): img_size = (img_size,img_size)
89
+ for k in keys:
90
+ if not k in checkpoint_model: continue
91
+ pos_embed_checkpoint = checkpoint_model[k]
92
+ embedding_size = pos_embed_checkpoint.shape[-1]
93
+ num_extra_tokens = 0 # no cls token
94
+ # height (== width) for the checkpoint position embedding
95
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
96
+ new_size = (img_size[0]//model.patch_embed.patch_size[0],img_size[1]//model.patch_embed.patch_size[1])
97
+ if orig_size != new_size[0] or orig_size != new_size[1]:
98
+ print("Position interpolate %s from %dx%d to %dx%d" % (k, orig_size, orig_size, new_size[0], new_size[1]))
99
+ extra_tokens = pos_embed_checkpoint[:num_extra_tokens,:]
100
+ pos_tokens = pos_embed_checkpoint[num_extra_tokens:,:]
101
+ pos_tokens = pos_tokens.reshape(1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
102
+ pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
103
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2).squeeze(0)
104
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
105
+ checkpoint_model[k] = new_pos_embed.squeeze(0)
106
+
107
+ #----------------------------------------------------------
108
+ # RoPE2D: RoPE implementation in 2D
109
+ #----------------------------------------------------------
110
+
111
+ # borrowed from https://github.com/naver/dust3r
112
+ # todo: replace with our official implementation
113
+
114
+ class RoPE2D(torch.nn.Module):
115
+ def __init__(self, freq=100.0, F0=1.0):
116
+ super().__init__()
117
+ self.base = freq
118
+ self.F0 = F0
119
+ self.cache = {}
120
+
121
+ def get_cos_sin(self, D, seq_len, device, dtype):
122
+ if (D,seq_len,device,dtype) not in self.cache:
123
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
124
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
125
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
126
+ freqs = torch.cat((freqs, freqs), dim=-1)
127
+ cos = freqs.cos() # (Seq, Dim)
128
+ sin = freqs.sin()
129
+ self.cache[D,seq_len,device,dtype] = (cos,sin)
130
+ return self.cache[D,seq_len,device,dtype]
131
+
132
+ @staticmethod
133
+ def rotate_half(x):
134
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
135
+ return torch.cat((-x2, x1), dim=-1)
136
+
137
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
138
+ assert pos1d.ndim==2
139
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
140
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
141
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
142
+
143
+ def forward(self, tokens, positions):
144
+ """
145
+ input:
146
+ * tokens: batch_size x nheads x ntokens x dim
147
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
148
+ output:
149
+ * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
150
+ """
151
+ assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
152
+ D = tokens.size(3) // 2
153
+ assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
154
+ cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
155
+ # split features into two along the feature dimension, and apply rope1d on each half
156
+ y, x = tokens.chunk(2, dim=-1)
157
+ y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
158
+ x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
159
+ tokens = torch.cat((y, x), dim=-1)
160
+ return tokens
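A small sketch of how RoPE2D and PositionGetter fit together (the same imports the aggregator below uses); the per-head dim is taken as a multiple of 4 so each y/x half can be rotated cleanly:

import torch
from vggt.layers.rope import RoPE2D, PositionGetter

B, heads, h, w, dim = 2, 4, 8, 8, 64
rope = RoPE2D(freq=100.0)
get_pos = PositionGetter()

pos = get_pos(B, h, w, device="cpu")        # (B, h*w, 2) integer (y, x) positions
tokens = torch.randn(B, heads, h * w, dim)  # (batch, nheads, ntokens, dim)
out = rope(tokens, pos)                     # rotary-encoded tokens, same shape
assert out.shape == tokens.shape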
vggt/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ class SwiGLUFFN(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ hidden_features: Optional[int] = None,
19
+ out_features: Optional[int] = None,
20
+ act_layer: Callable[..., nn.Module] = None,
21
+ drop: float = 0.0,
22
+ bias: bool = True,
23
+ ) -> None:
24
+ super().__init__()
25
+ out_features = out_features or in_features
26
+ hidden_features = hidden_features or in_features
27
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
28
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
29
+
30
+ def forward(self, x: Tensor) -> Tensor:
31
+ x12 = self.w12(x)
32
+ x1, x2 = x12.chunk(2, dim=-1)
33
+ hidden = F.silu(x1) * x2
34
+ return self.w3(hidden)
35
+
36
+
37
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
38
+ # try:
39
+ # if XFORMERS_ENABLED:
40
+ # from xformers.ops import SwiGLU
41
+
42
+ # XFORMERS_AVAILABLE = True
43
+ # warnings.warn("xFormers is available (SwiGLU)")
44
+ # else:
45
+ # warnings.warn("xFormers is disabled (SwiGLU)")
46
+ # raise ImportError
47
+ # except ImportError:
48
+ SwiGLU = SwiGLUFFN
49
+ XFORMERS_AVAILABLE = False
50
+
51
+ # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
+ class SwiGLUFFNFused(SwiGLU):
55
+ def __init__(
56
+ self,
57
+ in_features: int,
58
+ hidden_features: Optional[int] = None,
59
+ out_features: Optional[int] = None,
60
+ act_layer: Callable[..., nn.Module] = None,
61
+ drop: float = 0.0,
62
+ bias: bool = True,
63
+ ) -> None:
64
+ out_features = out_features or in_features
65
+ hidden_features = hidden_features or in_features
66
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
67
+ super().__init__(
68
+ in_features=in_features,
69
+ hidden_features=hidden_features,
70
+ out_features=out_features,
71
+ bias=bias,
72
+ )
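A quick sketch of SwiGLUFFNFused above; since the xformers branch is commented out, this resolves to the plain PyTorch SwiGLUFFN:

import torch
from vggt.layers.swiglu_ffn import SwiGLUFFNFused

ffn = SwiGLUFFNFused(in_features=768, hidden_features=4 * 768)
x = torch.randn(2, 197, 768)
y = ffn(x)
assert y.shape == x.shape
# effective hidden width: int(3072 * 2 / 3) rounded up to a multiple of 8 = 2048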
vggt/layers/vision_transformer.py ADDED
@@ -0,0 +1,408 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
20
+
21
+ logger = logging.getLogger("dinov2")
22
+
23
+
24
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
25
+ if not depth_first and include_root:
26
+ fn(module=module, name=name)
27
+ for child_name, child_module in module.named_children():
28
+ child_name = ".".join((name, child_name)) if name else child_name
29
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
30
+ if depth_first and include_root:
31
+ fn(module=module, name=name)
32
+ return module
33
+
34
+
35
+ class BlockChunk(nn.ModuleList):
36
+ def forward(self, x):
37
+ for b in self:
38
+ x = b(x)
39
+ return x
40
+
41
+
42
+ class DinoVisionTransformer(nn.Module):
43
+ def __init__(
44
+ self,
45
+ img_size=224,
46
+ patch_size=16,
47
+ in_chans=3,
48
+ embed_dim=768,
49
+ depth=12,
50
+ num_heads=12,
51
+ mlp_ratio=4.0,
52
+ qkv_bias=True,
53
+ ffn_bias=True,
54
+ proj_bias=True,
55
+ drop_path_rate=0.0,
56
+ drop_path_uniform=False,
57
+ init_values=None, # for layerscale: None or 0 => no layerscale
58
+ embed_layer=PatchEmbed,
59
+ act_layer=nn.GELU,
60
+ block_fn=Block,
61
+ ffn_layer="mlp",
62
+ block_chunks=1,
63
+ num_register_tokens=0,
64
+ interpolate_antialias=False,
65
+ interpolate_offset=0.1,
66
+ qk_norm=False,
67
+ ):
68
+ """
69
+ Args:
70
+ img_size (int, tuple): input image size
71
+ patch_size (int, tuple): patch size
72
+ in_chans (int): number of input channels
73
+ embed_dim (int): embedding dimension
74
+ depth (int): depth of transformer
75
+ num_heads (int): number of attention heads
76
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
77
+ qkv_bias (bool): enable bias for qkv if True
78
+ proj_bias (bool): enable bias for proj in attn if True
79
+ ffn_bias (bool): enable bias for ffn if True
80
+ drop_path_rate (float): stochastic depth rate
81
+ drop_path_uniform (bool): apply uniform drop rate across blocks
82
+ weight_init (str): weight init scheme
83
+ init_values (float): layer-scale init values
84
+ embed_layer (nn.Module): patch embedding layer
85
+ act_layer (nn.Module): MLP activation layer
86
+ block_fn (nn.Module): transformer block class
87
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
88
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
89
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
90
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
91
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
92
+ """
93
+ super().__init__()
94
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
95
+
96
+ # tricky but makes it work
97
+ self.use_checkpoint = False
98
+ #
99
+
100
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
101
+ self.num_tokens = 1
102
+ self.n_blocks = depth
103
+ self.num_heads = num_heads
104
+ self.patch_size = patch_size
105
+ self.num_register_tokens = num_register_tokens
106
+ self.interpolate_antialias = interpolate_antialias
107
+ self.interpolate_offset = interpolate_offset
108
+
109
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
110
+ num_patches = self.patch_embed.num_patches
111
+
112
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
113
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
114
+ assert num_register_tokens >= 0
115
+ self.register_tokens = (
116
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
117
+ )
118
+
119
+ if drop_path_uniform is True:
120
+ dpr = [drop_path_rate] * depth
121
+ else:
122
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
123
+
124
+ if ffn_layer == "mlp":
125
+ logger.info("using MLP layer as FFN")
126
+ ffn_layer = Mlp
127
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
128
+ logger.info("using SwiGLU layer as FFN")
129
+ ffn_layer = SwiGLUFFNFused
130
+ elif ffn_layer == "identity":
131
+ logger.info("using Identity layer as FFN")
132
+
133
+ def f(*args, **kwargs):
134
+ return nn.Identity()
135
+
136
+ ffn_layer = f
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ blocks_list = [
141
+ block_fn(
142
+ dim=embed_dim,
143
+ num_heads=num_heads,
144
+ mlp_ratio=mlp_ratio,
145
+ qkv_bias=qkv_bias,
146
+ proj_bias=proj_bias,
147
+ ffn_bias=ffn_bias,
148
+ drop_path=dpr[i],
149
+ norm_layer=norm_layer,
150
+ act_layer=act_layer,
151
+ ffn_layer=ffn_layer,
152
+ init_values=init_values,
153
+ qk_norm=qk_norm,
154
+ )
155
+ for i in range(depth)
156
+ ]
157
+ if block_chunks > 0:
158
+ self.chunked_blocks = True
159
+ chunked_blocks = []
160
+ chunksize = depth // block_chunks
161
+ for i in range(0, depth, chunksize):
162
+ # this is to keep the block index consistent if we chunk the block list
163
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
164
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
165
+ else:
166
+ self.chunked_blocks = False
167
+ self.blocks = nn.ModuleList(blocks_list)
168
+
169
+ self.norm = norm_layer(embed_dim)
170
+ self.head = nn.Identity()
171
+
172
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
173
+
174
+ self.init_weights()
175
+
176
+ def init_weights(self):
177
+ trunc_normal_(self.pos_embed, std=0.02)
178
+ nn.init.normal_(self.cls_token, std=1e-6)
179
+ if self.register_tokens is not None:
180
+ nn.init.normal_(self.register_tokens, std=1e-6)
181
+ named_apply(init_weights_vit_timm, self)
182
+
183
+ def interpolate_pos_encoding(self, x, w, h):
184
+ previous_dtype = x.dtype
185
+ npatch = x.shape[1] - 1
186
+ N = self.pos_embed.shape[1] - 1
187
+ if npatch == N and w == h:
188
+ return self.pos_embed
189
+ pos_embed = self.pos_embed.float()
190
+ class_pos_embed = pos_embed[:, 0]
191
+ patch_pos_embed = pos_embed[:, 1:]
192
+ dim = x.shape[-1]
193
+ w0 = w // self.patch_size
194
+ h0 = h // self.patch_size
195
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
196
+ assert N == M * M
197
+ kwargs = {}
198
+ if self.interpolate_offset:
199
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
200
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
201
+ sx = float(w0 + self.interpolate_offset) / M
202
+ sy = float(h0 + self.interpolate_offset) / M
203
+ kwargs["scale_factor"] = (sx, sy)
204
+ else:
205
+ # Simply specify an output size instead of a scale factor
206
+ kwargs["size"] = (w0, h0)
207
+ patch_pos_embed = nn.functional.interpolate(
208
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
209
+ mode="bicubic",
210
+ antialias=self.interpolate_antialias,
211
+ **kwargs,
212
+ )
213
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
214
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
215
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
216
+
217
+ def prepare_tokens_with_masks(self, x, masks=None):
218
+ B, nc, w, h = x.shape
219
+ x = self.patch_embed(x)
220
+ if masks is not None:
221
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
222
+
223
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
224
+ x = x + self.interpolate_pos_encoding(x, w, h)
225
+
226
+ if self.register_tokens is not None:
227
+ x = torch.cat(
228
+ (
229
+ x[:, :1],
230
+ self.register_tokens.expand(x.shape[0], -1, -1),
231
+ x[:, 1:],
232
+ ),
233
+ dim=1,
234
+ )
235
+
236
+ return x
237
+
238
+ def forward_features_list(self, x_list, masks_list):
239
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
240
+
241
+
242
+ for blk in self.blocks:
243
+ if self.use_checkpoint:
244
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
245
+ else:
246
+ x = blk(x)
247
+
248
+ all_x = x
249
+ output = []
250
+ for x, masks in zip(all_x, masks_list):
251
+ x_norm = self.norm(x)
252
+ output.append(
253
+ {
254
+ "x_norm_clstoken": x_norm[:, 0],
255
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
256
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
257
+ "x_prenorm": x,
258
+ "masks": masks,
259
+ }
260
+ )
261
+ return output
262
+
263
+ def forward_features(self, x, masks=None):
264
+ if isinstance(x, list):
265
+ return self.forward_features_list(x, masks)
266
+
267
+ x = self.prepare_tokens_with_masks(x, masks)
268
+
269
+ for blk in self.blocks:
270
+ if self.use_checkpoint:
271
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
272
+ else:
273
+ x = blk(x)
274
+
275
+ x_norm = self.norm(x)
276
+ return {
277
+ "x_norm_clstoken": x_norm[:, 0],
278
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
279
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
280
+ "x_prenorm": x,
281
+ "masks": masks,
282
+ }
283
+
284
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
285
+ x = self.prepare_tokens_with_masks(x)
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ output, total_block_len = [], len(self.blocks)
288
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
289
+ for i, blk in enumerate(self.blocks):
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
294
+ return output
295
+
296
+ def _get_intermediate_layers_chunked(self, x, n=1):
297
+ x = self.prepare_tokens_with_masks(x)
298
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
299
+ # If n is an int, take the n last blocks. If it's a list, take them
300
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
301
+ for block_chunk in self.blocks:
302
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
303
+ x = blk(x)
304
+ if i in blocks_to_take:
305
+ output.append(x)
306
+ i += 1
307
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
308
+ return output
309
+
310
+ def get_intermediate_layers(
311
+ self,
312
+ x: torch.Tensor,
313
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
314
+ reshape: bool = False,
315
+ return_class_token: bool = False,
316
+ norm=True,
317
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
318
+ if self.chunked_blocks:
319
+ outputs = self._get_intermediate_layers_chunked(x, n)
320
+ else:
321
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
322
+ if norm:
323
+ outputs = [self.norm(out) for out in outputs]
324
+ class_tokens = [out[:, 0] for out in outputs]
325
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
326
+ if reshape:
327
+ B, _, w, h = x.shape
328
+ outputs = [
329
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
330
+ for out in outputs
331
+ ]
332
+ if return_class_token:
333
+ return tuple(zip(outputs, class_tokens))
334
+ return tuple(outputs)
335
+
336
+ def forward(self, *args, is_training=True, **kwargs):
337
+ ret = self.forward_features(*args, **kwargs)
338
+ if is_training:
339
+ return ret
340
+ else:
341
+ return self.head(ret["x_norm_clstoken"])
342
+
343
+
344
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
345
+ """ViT weight initialization, original timm impl (for reproducibility)"""
346
+ if isinstance(module, nn.Linear):
347
+ trunc_normal_(module.weight, std=0.02)
348
+ if module.bias is not None:
349
+ nn.init.zeros_(module.bias)
350
+
351
+
352
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
353
+ model = DinoVisionTransformer(
354
+ patch_size=patch_size,
355
+ embed_dim=384,
356
+ depth=12,
357
+ num_heads=6,
358
+ mlp_ratio=4,
359
+ block_fn=partial(Block, attn_class=MemEffAttention),
360
+ num_register_tokens=num_register_tokens,
361
+ **kwargs,
362
+ )
363
+ return model
364
+
365
+
366
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
367
+ model = DinoVisionTransformer(
368
+ patch_size=patch_size,
369
+ embed_dim=768,
370
+ depth=12,
371
+ num_heads=12,
372
+ mlp_ratio=4,
373
+ block_fn=partial(Block, attn_class=MemEffAttention),
374
+ num_register_tokens=num_register_tokens,
375
+ **kwargs,
376
+ )
377
+ return model
378
+
379
+
380
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
381
+ model = DinoVisionTransformer(
382
+ patch_size=patch_size,
383
+ embed_dim=1024,
384
+ depth=24,
385
+ num_heads=16,
386
+ mlp_ratio=4,
387
+ block_fn=partial(Block, attn_class=MemEffAttention),
388
+ num_register_tokens=num_register_tokens,
389
+ **kwargs,
390
+ )
391
+ return model
392
+
393
+
394
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
395
+ """
396
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
397
+ """
398
+ model = DinoVisionTransformer(
399
+ patch_size=patch_size,
400
+ embed_dim=1536,
401
+ depth=40,
402
+ num_heads=24,
403
+ mlp_ratio=4,
404
+ block_fn=partial(Block, attn_class=MemEffAttention),
405
+ num_register_tokens=num_register_tokens,
406
+ **kwargs,
407
+ )
408
+ return model
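A hedged usage sketch for the ViT factories above. It assumes the attention/block layers they import (MemEffAttention, NestedTensorBlock) are present elsewhere in this commit, and mirrors how the aggregator below configures the backbone (patch_size=16, block_chunks=0, init_values=1.0):

import torch
from vggt.layers.vision_transformer import vit_large

model = vit_large(patch_size=16, img_size=512, num_register_tokens=4, block_chunks=0, init_values=1.0)
x = torch.randn(1, 3, 512, 512)
out = model(x)                                   # is_training defaults to True -> feature dict
patches = out["x_norm_patchtokens"]              # (1, (512 // 16) ** 2, 1024)
assert patches.shape == (1, 32 * 32, 1024)

# Intermediate features from the last 4 blocks, reshaped into (B, C, H/16, W/16) maps:
feats = model.get_intermediate_layers(x, n=4, reshape=True)
assert feats[-1].shape == (1, 1024, 32, 32)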
vggt/models/aggregator.py ADDED
@@ -0,0 +1,473 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import logging
9
+
10
+ import pdb
11
+ import math
12
+ import numpy as np
13
+ import os
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from hydra.utils import instantiate
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from torch.utils.checkpoint import checkpoint
21
+ from omegaconf import OmegaConf
22
+ from contextlib import nullcontext
23
+
24
+ from typing import Any, Dict, List, Optional, Tuple, Union
25
+
26
+ # from off3d.utils.train_utils import remove_if_not_match
27
+
28
+ # from off3d.models.modules import AttnBlock, CrossAttnBlock, Mlp, ResidualBlock, RoPEAttnBlock
29
+ # from vggsfm.models.utils import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid
30
+ # from off3d.models.dino_layers import SwiGLUFFNFused, PatchEmbed
31
+
32
+ from vggt.layers import SwiGLUFFNFused, PatchEmbed
33
+ from vggt.layers.block import Block
34
+
35
+ # from off3d.models.dino_layers.block import Block
36
+ # from vggt.layers.rope import RoPE2D, PositionGetter
37
+ from vggt.layers.rope import RoPE2D, PositionGetter
38
+
39
+ # from off3d.models.multihead_with_qk_norm import MultiheadAttention_with_qk_norm
40
+ # from off3d.models.rope import RoPEMulitheadAttention
41
+
42
+ from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2
43
+
44
+
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+
49
+ _RESNET_MEAN = [0.485, 0.456, 0.406]
50
+ _RESNET_STD = [0.229, 0.224, 0.225]
51
+
52
+
53
+
54
+
55
+ class Aggregator(nn.Module):
56
+ def __init__(
57
+ self,
58
+ image_size = 512,
59
+ patch_size = 16,
60
+ num_register_tokens = 4,
61
+ image_backbone = "dinov2_vitl14_reg",
62
+ aa_block_size = 1,
63
+ aa_layer_size = 24,
64
+ aa_block_kwargs = Dict,
65
+ attn_block = Block,
66
+ aa_order = ["frame", "global"],
67
+ use_checkpoint = False,
68
+ use_reentrant = False,
69
+ use_dino_tokens = False,
70
+ use_patch_tokens_only = False,
71
+ freeze_dino=False,
72
+ freeze_dino_inter=False,
73
+ # pose_embed=False,
74
+ embed_type="no",
75
+ patch_embed_by_conv=False,
76
+ decoder_load_dino=False,
77
+ backbone_qk_norm=False,
78
+ **kwargs,
79
+ ):
80
+ super().__init__()
81
+
82
+ if image_backbone is None:
83
+ self.image_backbone = None
84
+ else:
85
+ self.__build_image_backbone__(image_backbone, image_size,
86
+ patch_size, num_register_tokens, freeze_dino=freeze_dino,
87
+ freeze_dino_inter=freeze_dino_inter, backbone_qk_norm=backbone_qk_norm)
88
+
89
+
90
+ self.freeze_dino = freeze_dino
91
+
92
+ if use_checkpoint and not freeze_dino:
93
+ self.image_backbone.use_checkpoint = True
94
+ else:
95
+ self.image_backbone.use_checkpoint = False
96
+
97
+ self.image_backbone.use_reentrant = use_reentrant
98
+
99
+ if aa_block_kwargs['rope_freq']>0:
100
+ self.rope = RoPE2D(freq=aa_block_kwargs['rope_freq'])
101
+ self.position_getter = PositionGetter()
102
+ else:
103
+ self.rope = None
104
+
105
+ frame_blocks_list = []
106
+ global_blocks_list = []
107
+ for _ in range(aa_layer_size):
108
+ frame_blocks_list.append(attn_block(**aa_block_kwargs, rope=self.rope))
109
+ global_blocks_list.append(attn_block(**aa_block_kwargs, rope=self.rope))
110
+
111
+ self.frame_blocks = nn.ModuleList(frame_blocks_list)
112
+ self.global_blocks = nn.ModuleList(global_blocks_list)
113
+
114
+ if "mlp" in embed_type:
115
+ self.register_mlp = nn.ModuleList([nn.Linear(aa_block_kwargs['dim'], aa_block_kwargs['dim']) for _ in range(aa_layer_size)])
116
+
117
+ self.aa_order = aa_order
118
+ self.aa_block_size = aa_block_size
119
+ self.aa_layer_size = aa_layer_size
120
+
121
+ assert self.aa_layer_size % self.aa_block_size == 0, "aa_layer_size must be divisible by aa_block_size"
122
+ self.aa_block_num = self.aa_layer_size // self.aa_block_size
123
+
124
+ self.patch_size = patch_size
125
+ self.use_checkpoint = use_checkpoint
126
+ self.use_reentrant = use_reentrant
127
+ self.use_dino_tokens = use_dino_tokens
128
+ self.use_patch_tokens_only = use_patch_tokens_only
129
+ # self.pose_embed = pose_embed
130
+ # self.register_embed = register_embed
131
+ self.embed_type = embed_type
132
+
133
+ if self.use_patch_tokens_only:
134
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, 1, aa_block_kwargs['dim']))
135
+ self.patch_start_idx = 0
136
+ nn.init.normal_(self.query_ref_token, std=1e-6)
137
+ elif self.use_dino_tokens:
138
+ # One for query frame and one for other frames
139
+ self.query_ref_token = nn.Parameter(torch.randn(1, 2, 1, aa_block_kwargs['dim']))
140
+ self.patch_start_idx = 1 + num_register_tokens + 1
141
+ nn.init.normal_(self.query_ref_token, std=1e-6)
142
+ else:
143
+ self.pose_token = nn.Parameter(torch.randn(1, 2, 1, aa_block_kwargs['dim']))
144
+ self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, aa_block_kwargs['dim']))
145
+ self.patch_start_idx = 1 + num_register_tokens
146
+ nn.init.normal_(self.pose_token, std=1e-6)
147
+ nn.init.normal_(self.register_token, std=1e-6)
148
+
149
+
150
+ if decoder_load_dino:
151
+ dinov2_weights = self.image_backbone.state_dict()
152
+ decoder_dinov2_weights = dino_to_aggregator(dinov2_weights)
153
+ missing_keys, unexpected_keys = self.load_state_dict(decoder_dinov2_weights, strict=False)
154
+ print(f"missing_keys for decoder_load_dino: {missing_keys}")
155
+ print(f"unexpected_keys for decoder_load_dino: {unexpected_keys}")
156
+
157
+ if patch_embed_by_conv:
158
+ self.image_backbone = self.image_backbone.patch_embed
159
+
160
+
161
+ for name, value in (
162
+ ("_resnet_mean", _RESNET_MEAN),
163
+ ("_resnet_std", _RESNET_STD),
164
+ ):
165
+ self.register_buffer(
166
+ name,
167
+ torch.FloatTensor(value).view(1, 1, 3, 1, 1),
168
+ persistent=False,
169
+ )
170
+
171
+
172
+ def __build_image_backbone__(self, image_backbone, image_size, patch_size, num_register_tokens,
173
+ interpolate_antialias=True,
174
+ interpolate_offset=0.0,
175
+ block_chunks=0,
176
+ init_values=1.0,
177
+ freeze_dino=False,
178
+ freeze_dino_inter=False,
179
+ backbone_qk_norm=False,
180
+ ):
181
+
182
+ vit_models = { "dinov2_vitl14_reg": vit_large,
183
+ "dinov2_vitb14_reg": vit_base,
184
+ "dinov2_vits14_reg": vit_small,
185
+ "dinov2_vitg2_reg": vit_giant2,
186
+ }
187
+
188
+ if image_backbone not in vit_models:
189
+ raise NotImplementedError
190
+
191
+ self.image_backbone = vit_models[image_backbone](img_size=image_size,
192
+ patch_size=patch_size, num_register_tokens=num_register_tokens,
193
+ interpolate_antialias=interpolate_antialias,
194
+ interpolate_offset=interpolate_offset,
195
+ block_chunks=block_chunks, init_values=init_values, qk_norm=backbone_qk_norm)
196
+
197
+ # pretrained_model = torch.hub.load("facebookresearch/dinov2", image_backbone)
198
+ # pretrained_model_dict = pretrained_model.state_dict()
199
+ # image_backbone_dict = self.image_backbone.state_dict()
200
+
201
+ # all_pretrained_keys = list(pretrained_model_dict.keys())
202
+
203
+ # for cur_key in all_pretrained_keys:
204
+ # pretrained_model_dict = remove_if_not_match(image_backbone_dict, pretrained_model_dict, cur_key)
205
+
206
+ # missing_keys, unexpected_keys = self.image_backbone.load_state_dict(pretrained_model_dict, strict=False)
207
+
208
+ self.image_backbone.mask_token.requires_grad_(False)
209
+ # self.image_backbone.freeze_dino = freeze_dino
210
+
211
+ # if freeze_dino:
212
+ # print("Freezing DINO layers")
213
+ # for name, param in self.image_backbone.named_parameters():
214
+ # param.requires_grad_(False)
215
+
216
+ # if freeze_dino_inter:
217
+ # print("Freezing DINO intermediate layers")
218
+ # for name, param in self.image_backbone.named_parameters():
219
+ # if name not in ['pos_embed', 'patch_embed.proj.weight']:
220
+ # param.requires_grad_(False)
221
+
222
+
223
+ # print("Loading pretrained DINO v2 model: ")
224
+ # print(f"missing_keys: {missing_keys}")
225
+ # print("Loading pretrained DINO v2 model: ")
226
+ # print(f"unexpected_keys: {unexpected_keys}")
227
+
228
+
229
+ def forward(
230
+ self, images,
231
+ masks=None,
232
+ batch=None,
233
+ ):
234
+ """
235
+ TODO List:
236
+
237
+ """
238
+
239
+ # The input images are in the range of [0, 1]
240
+ B, S, C_in, H, W = images.shape
241
+ device = images.device
242
+
243
+
244
+ images = (images - self._resnet_mean) / self._resnet_std
245
+
246
+
247
+ if self.image_backbone is not None:
248
+ images = images.view(B * S, C_in, H, W)
249
+
250
+ with torch.no_grad() if self.freeze_dino else nullcontext():
251
+ backbone_output = self.image_backbone(images)
252
+
253
+ if isinstance(backbone_output, dict):
254
+ patch_tokens = backbone_output["x_norm_patchtokens"]
255
+ else:
256
+ patch_tokens = backbone_output
257
+
258
+ BS, P, C = patch_tokens.shape
259
+
260
+ if self.use_patch_tokens_only:
261
+ indicator_tokens = slice_expand_and_flatten(self.query_ref_token, B, S)
262
+ tokens = patch_tokens + indicator_tokens
263
+ elif self.use_dino_tokens:
264
+ dino_cls_token = backbone_output["x_norm_clstoken"][:, None] # BS, 1, C
265
+ dino_register_tokens = backbone_output["x_norm_regtokens"] # BS, num_register_tokens, C
266
+
267
+ indicator_tokens = slice_expand_and_flatten(self.query_ref_token, B, S)
268
+ tokens = torch.cat([dino_cls_token, dino_register_tokens, indicator_tokens, patch_tokens], dim=1)
269
+ else:
270
+ # B, S, P, C
271
+ pose_token = slice_expand_and_flatten(self.pose_token, B, S)
272
+ register_token = slice_expand_and_flatten(self.register_token, B, S)
273
+
274
+ tokens = torch.cat([pose_token, register_token, patch_tokens], dim=1)
275
+ else:
276
+ # Backbone-free token path is not implemented yet
277
+ raise NotImplementedError
278
+
279
+
280
+ if self.rope is not None:
281
+ pos = self.position_getter(B*S, H//self.patch_size, W//self.patch_size, device=device)
282
+ else:
283
+ pos = None
284
+
285
+
286
+
287
+ if self.patch_start_idx > 0:
288
+ # shift the position by 1 so that the special tokens are at 0
289
+ pos = pos + 1
290
+ pos_special = torch.zeros(B*S, self.patch_start_idx, 2).to(device).to(pos.dtype)
291
+ pos = torch.cat([pos_special, pos], dim=1)
292
+
293
+
294
+ _, P, C = tokens.shape
295
+
296
+
297
+ frame_idx = 0
298
+ global_idx = 0
299
+ output_list = []
300
+
301
+
302
+ for aa_block_idx in range(self.aa_block_num):
303
+ for attn_type in self.aa_order:
304
+ if attn_type == "frame":
305
+ tokens, frame_idx, frame_intermediates = self._process_frame_attention(
306
+ tokens, B, S, P, C, frame_idx, self.aa_block_size, pos=pos
307
+ )
308
+ elif attn_type == "global":
309
+ tokens, global_idx, global_intermediates = self._process_global_attention(
310
+ tokens, B, S, P, C, global_idx, self.aa_block_size, pos=pos
311
+ )
312
+ else:
313
+ raise ValueError(f"Unknown attention type: {attn_type}")
314
+
315
+
316
+ # for frame_inter, global_inter in zip(frame_intermediates, global_intermediates):
317
+ # concat_inter = torch.cat([frame_inter, global_inter], dim=-1) # [B x S x P x 2C]
318
+ # output_list.append(concat_inter)
319
+
320
+ for i in range(len(frame_intermediates)):
321
+ # [B x S x P x 2C]
322
+ concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
323
+ output_list.append(concat_inter)
324
+
325
+
326
+ del concat_inter
327
+ del frame_intermediates
328
+ del global_intermediates
329
+ return output_list, None, self.patch_start_idx
330
+
331
+
332
+
333
+ def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, num_blocks, pos=None):
334
+ """
335
+ Process frame attention blocks.
336
+ """
337
+ if tokens.shape != (B*S, P, C):
338
+ tokens = tokens.view(B, S, P, C)
339
+ tokens = tokens.view(B*S, P, C)
340
+
341
+ if pos is not None and pos.shape != (B*S, P, 2):
342
+ pos = pos.view(B, S, P, 2)
343
+ pos = pos.view(B*S, P, 2)
344
+
345
+ intermediates = []
346
+
347
+ for _ in range(num_blocks):
348
+ if self.use_checkpoint:
349
+ tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
350
+ else:
351
+ tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
352
+ frame_idx += 1
353
+ intermediates.append(tokens.view(B, S, P, C))
354
+
355
+ return tokens, frame_idx, intermediates
356
+
357
+ def _process_global_attention(self, tokens, B, S, P, C, global_idx, num_blocks, pos=None):
358
+ """
359
+ Process global attention blocks.
360
+ """
361
+ # pose_embed
362
+
363
+ if tokens.shape != (B, S*P, C):
364
+ tokens = tokens.view(B, S, P, C)
365
+
366
+
367
+ ############################################################
368
+ # Frame embedding
369
+ if "register" in self.embed_type:
370
+ embed_tokens = tokens[:, :, 1:2, ...].clone()
371
+ if "gauss" in self.embed_type:
372
+ embed_tokens = torch.randn((B, S, 1, C),device=tokens.device, dtype=tokens.dtype)
373
+
374
+ if self.embed_type != "no":
375
+ embed_tokens = F.normalize(embed_tokens, dim=-1)
376
+
377
+ if "mlp" in self.embed_type:
378
+ embed_tokens = self.register_mlp[global_idx](embed_tokens)
379
+
380
+ if "mlpnorm" in self.embed_type:
381
+ embed_tokens = F.normalize(embed_tokens, dim=-1)
382
+ if "all" in self.embed_type:
383
+ tokens = tokens + embed_tokens
384
+ elif "part" in self.embed_type:
385
+ tokens[:, :, self.patch_start_idx:] = tokens[:, :, self.patch_start_idx:] + embed_tokens
386
+ else:
387
+ assert self.embed_type == "no"
388
+
389
+ if "postnorm" in self.embed_type:
390
+ tokens = F.normalize(tokens, dim=-1)
391
+ # tokens = self.embed_norm(tokens)
392
+ ############################################################
393
+
394
+
395
+
396
+ tokens = tokens.view(B, S*P, C)
397
+
398
+ if pos is not None and pos.shape != (B, S*P, 2):
399
+ pos = pos.view(B, S, P, 2)
400
+ pos = pos.view(B, S*P, 2)
401
+
402
+ intermediates = []
403
+ for _ in range(num_blocks):
404
+ if self.use_checkpoint:
405
+ tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
406
+ else:
407
+ tokens = self.global_blocks[global_idx](tokens, pos=pos)
408
+ global_idx += 1
409
+ intermediates.append(tokens.view(B, S, P, C))
410
+
411
+ return tokens, global_idx, intermediates
412
+
413
+
414
+
415
+
416
+ def slice_expand_and_flatten(token_tensor, B, S):
417
+ """
418
+ 1) Takes the first token (index=0) for the query frame and the remaining token(s) for the other S-1 frames.
419
+ 2) Expands them along batch dimension B.
420
+ 3) Concatenates along the time/sequence dimension => (B, S, ...).
421
+ 4) Flattens the first two dims to produce => (B*S, ...).
422
+
423
+ Args:
424
+ token_tensor: a tensor expected to have shape (1, S, ...) or (some_batch, S, ...).
425
+ We'll slice along dim=1.
426
+ B: batch size.
427
+ S: number of frames/time-steps.
428
+
429
+ Returns:
430
+ Flattened token tensor of shape (B*S, ...).
431
+ """
432
+
433
+ # Slice out the "query" tokens => shape (1, 1, ...)
434
+ query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
435
+ # Slice out the "other" tokens => shape (1, S-1, ...)
436
+ others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
437
+ # Concatenate => shape (B, S, ...)
438
+ combined = torch.cat([query, others], dim=1)
439
+
440
+ # Finally flatten => shape (B*S, ...)
441
+ combined = combined.view(B * S, *combined.shape[2:])
442
+ return combined
443
+
444
+
445
+
446
+
447
+ def dino_to_aggregator(dinov2_weights):
448
+ new_dinov2_weights = {}
449
+ for key, value in dinov2_weights.items():
450
+ if "blocks" in key:
451
+ for new_attn_key in ["frame_blocks", "global_blocks"]:
452
+ new_key = key.replace("blocks", new_attn_key)
453
+ # if 'attn' in key:
454
+ # if "qkv.weight" in key:
455
+ # new_key = new_key.replace('qkv.weight', 'in_proj_weight')
456
+ # elif "qkv.bias" in key:
457
+ # new_key = new_key.replace('qkv.bias', 'in_proj_bias')
458
+ # elif 'proj.weight' in key:
459
+ # new_key = new_key.replace('proj.weight', 'out_proj.weight')
460
+ # elif 'proj.bias' in key:
461
+ # new_key = new_key.replace('proj.bias', 'out_proj.bias')
462
+ new_dinov2_weights[new_key] = value.clone()
463
+ return new_dinov2_weights
464
+
465
+
466
+
467
+
468
+ def remove_if_not_match(model_state_dict, state_dict, key):
469
+ if key in state_dict.keys() and key in model_state_dict.keys():
470
+ if state_dict[key].shape != model_state_dict[key].shape:
471
+ print(f"Warning: {key} shape mismatch, removing it")
472
+ del state_dict[key]
473
+ return state_dict
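A short, self-contained check of the slice_expand_and_flatten helper above, which is how the learned pose/register tokens are shared: slot 0 of the (1, 2, ...) parameter goes to the query frame, slot 1 is reused for the remaining S-1 frames, and the result is flattened to B*S for per-frame attention. (Importing this module also requires hydra/omegaconf, per the imports at the top of the file.)

import torch
from vggt.models.aggregator import slice_expand_and_flatten  # module-level helper defined above

B, S, C = 2, 5, 1024
pose_token = torch.randn(1, 2, 1, C)                     # learned (1, 2, 1, C) parameter, as in Aggregator.__init__
flat = slice_expand_and_flatten(pose_token, B, S)
assert flat.shape == (B * S, 1, C)
per_frame = flat.view(B, S, 1, C)
assert torch.equal(per_frame[0, 0], pose_token[0, 0])    # query frame -> slot 0
assert torch.equal(per_frame[0, 1], pose_token[0, 1])    # every other frame -> slot 1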
vggt/models/vggt.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ # from off3d.models.vggt.utils import random_mask_single_patch_vectorized # Removed unused import
8
+ from hydra.utils import instantiate
9
+ # from .loss import *
10
+
11
+ def configure_dict(module, **attributes):
12
+ if module:
13
+ for attr, value in attributes.items():
14
+ setattr(module, attr, value)
15
+
16
+
17
+ class VGGT(nn.Module):
18
+ def __init__(self,
19
+ AGGREGATOR: Dict,
20
+ CameraHead: Dict,
21
+ PointHead: Dict,
22
+ DepthHead: Dict,
23
+ MatchHead: Dict,
24
+ TrackHead: Dict,
25
+ num_register_tokens,
26
+ init_values,
27
+ qk_norm,
28
+ ffn_layer,
29
+ patch_size,
30
+ enable_head_mp=False,
31
+ **kwargs):
32
+ super().__init__()
33
+
34
+ config_attrs = {
35
+ 'patch_size': patch_size,
36
+ 'init_values': init_values,
37
+ 'qk_norm': qk_norm,
38
+ 'ffn_layer': ffn_layer,
39
+ 'num_register_tokens': num_register_tokens
40
+ }
41
+
42
+
43
+ if AGGREGATOR:
44
+ configure_dict(AGGREGATOR, **config_attrs)
45
+ self.aggregator = instantiate(AGGREGATOR, _recursive_=False)
46
+ else:
47
+ self.aggregator = None
48
+
49
+ if CameraHead:
50
+ configure_dict(CameraHead, **config_attrs)
51
+ CameraHead.loss_kwargs.pose_encoding_type = CameraHead.pose_encoding_type
52
+ self.camera_head_loss_kwargs = CameraHead.loss_kwargs
53
+ self.camera_head = instantiate(CameraHead, _recursive_=False)
54
+ else:
55
+ self.camera_head = None
56
+
57
+ if PointHead:
58
+ configure_dict(PointHead, **config_attrs)
59
+ self.point_head_loss_kwargs = PointHead.loss_kwargs
60
+ self.point_head = instantiate(PointHead, _recursive_=False)
61
+ else:
62
+ self.point_head = None
63
+
64
+ if DepthHead:
65
+ configure_dict(DepthHead, **config_attrs)
66
+ self.depth_head_loss_kwargs = DepthHead.loss_kwargs
67
+ self.depth_head = instantiate(DepthHead, _recursive_=False)
68
+ else:
69
+ self.depth_head = None
70
+
71
+ if MatchHead:
72
+ configure_dict(MatchHead, **config_attrs)
73
+ self.match_head_loss_kwargs = MatchHead.loss_kwargs
74
+ self.match_head = instantiate(MatchHead, _recursive_=False)
75
+ else:
76
+ self.match_head = None
77
+
78
+ if TrackHead:
79
+ configure_dict(TrackHead, **config_attrs)
80
+ self.track_head_loss_kwargs = TrackHead.loss_kwargs
81
+ self.track_head = instantiate(TrackHead, _recursive_=False)
82
+ else:
83
+ self.track_head = None
84
+
85
+ self.enable_head_mp = enable_head_mp
86
+ # self.mask_patch_ratio = mask_patch_ratio
87
+ # self.mask_patch_size = mask_patch_size
88
+
89
+
90
+ def forward(self, batch, device=None):
91
+ images = (batch["images"]) #.to(device) # B x S x 3 x H x W
92
+ # intrinsics = (batch["intrinsics"])#.to(device)
93
+ # extrinsics = (batch["extrinsics"])#.to(device)
94
+ B, S, C, H, W = images.shape
95
+
96
+
97
+ # if self.training and self.mask_patch_ratio > 0: # Commented out masking
98
+ # for _ in range(1000):
99
+ # print("Please do not use mask_patch_ratio for now")
100
+
101
+ # predictions = {} # Removed redundant dict
102
+
103
+ aggregated_tokens_list, _, patch_start_idx = self.aggregator(images, batch=batch)
104
+
105
+
106
+ # Pose branch
107
+ # TODO check pose encoding conversion # Removed TODO
108
+ # loss = 0
109
+
110
+
111
+ predictions = {}
112
+
113
+
114
+
115
+ # well by default we use amp for track head
116
+ if self.track_head is not None:
117
+ track_loss_dict = self.track_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
118
+ predictions.update(track_loss_dict)
119
+
120
+
121
+ with torch.cuda.amp.autocast(enabled=self.enable_head_mp):
122
+ if self.camera_head is not None:
123
+ pred_pose_enc_list = self.camera_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
124
+ camera_loss_dict = {}
125
+ camera_loss_dict["pred_extrinsic_list"] = pred_pose_enc_list
126
+ # with torch.cuda.amp.autocast(enabled=False):
127
+ # if not isinstance(pred_pose_enc_list, dict):
128
+ # camera_loss_dict, last_pred_extrinsic = camera_loss(pred_pose_enc_list, batch, **self.camera_head_loss_kwargs)
129
+ # predictions["pred_extrinsic"] = last_pred_extrinsic
130
+ # else:
131
+ # camera_loss_dict = pred_pose_enc_list
132
+ predictions.update(camera_loss_dict)
133
+
134
+ if self.point_head is not None:
135
+ pts3d, pts3d_conf = self.point_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
136
+ # with torch.cuda.amp.autocast(enabled=False):
137
+ # pts3d_loss_dict = point_loss(pts3d, pts3d_conf, batch, **self.point_head_loss_kwargs)
138
+ # predictions.update(pts3d_loss_dict)
139
+ predictions["pred_world_points"] = pts3d
140
+ predictions["pred_world_points_conf"] = pts3d_conf
141
+
142
+ if self.depth_head is not None:
143
+ depth, depth_conf = self.depth_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
144
+ # with torch.cuda.amp.autocast(enabled=False):
145
+ # depth_loss_dict = depth_loss(depth, depth_conf, batch, **self.depth_head_loss_kwargs)
146
+ # predictions.update(depth_loss_dict)
147
+ predictions["pred_depth"] = depth
148
+ predictions["pred_depth_conf"] = depth_conf
149
+
150
+ if self.match_head is not None:
151
+ match_loss_dict = self.match_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
152
+ predictions.update(match_loss_dict)
153
+
154
+ predictions.update(batch)
155
+
156
+ return predictions
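A hedged sketch of the forward contract for VGGT above; the head configs are hydra nodes defined elsewhere in the repo, so construction is only indicated in comments, and the output key names come from the forward pass shown here:

import torch

batch = {"images": torch.rand(1, 4, 3, 512, 512)}   # B x S x 3 x H x W, values in [0, 1]

# model = VGGT(AGGREGATOR=agg_cfg, CameraHead=cam_cfg, PointHead=point_cfg, DepthHead=depth_cfg,
#              MatchHead=None, TrackHead=None, num_register_tokens=4, init_values=1.0,
#              qk_norm=False, ffn_layer="mlp", patch_size=16)   # cfgs are hydra configs, not shown in this diff
# predictions = model(batch)
# Depending on which heads are configured, predictions contains:
#   "pred_extrinsic_list"                            (camera head)
#   "pred_world_points", "pred_world_points_conf"    (point head)
#   "pred_depth", "pred_depth_conf"                  (depth head)
# plus everything already in `batch`.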
vggt/utils/pose_enc.py ADDED
@@ -0,0 +1,126 @@
1
+ import torch
2
+ from .rotation import quat_to_mat, mat_to_quat
3
+ # from off3d.utils.metric import closed_form_inverse_OpenCV
4
+
5
+
6
+ def extri_intri_to_pose_encoding(
7
+ extrinsics,
8
+ intrinsics,
9
+ image_size_hw = None, # e.g., (256, 512)
10
+ pose_encoding_type="absT_quaR_FoV",
11
+ min_focal_length=0.1,
12
+ max_focal_length=10,):
13
+
14
+ # extrinsics: BxSx3x4
15
+ # intrinsics: BxSx3x3
16
+
17
+
18
+ if pose_encoding_type=="absT_quaR_FoV":
19
+ R = extrinsics[:, :, :3, :3] # BxSx3x3
20
+ T = extrinsics[:, :, :3, 3] # BxSx3
21
+
22
+ quat = mat_to_quat(R)
23
+ # R_reverse = quat_to_mat(quat)
24
+ # Note the order of h and w here
25
+ H, W = image_size_hw
26
+ fov_h = 2 * torch.atan((H /2) / intrinsics[..., 1, 1])
27
+ fov_w = 2 * torch.atan((W /2) / intrinsics[..., 0, 0])
28
+ pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
29
+ elif pose_encoding_type=="absT_quaR_OneFLM1":
30
+ # raise ValueError("Not checked after mitigrating to off3d.")
31
+ focal_length = intrinsics[:, :, [0,1], [0,1]] / max(image_size_hw)
32
+ focal_length = focal_length.mean(dim=-1)
33
+ focal_length = focal_length.clamp(min_focal_length, max_focal_length)
34
+ focal_length = focal_length - 1
35
+ R = extrinsics[:, :, :3, :3]
36
+ T = extrinsics[:, :, :3, 3]
37
+ quat = mat_to_quat(R)
38
+ pose_encoding = torch.cat([T, quat, focal_length[..., None]], dim=-1).float()
39
+ else:
40
+ raise NotImplementedError
41
+
42
+ return pose_encoding
43
+
44
+
45
+
46
+ def pose_encoding_to_extri_intri(
47
+ pose_encoding,
48
+ image_size_hw=None, # e.g., (256, 512)
49
+ min_focal_length=0.1,
50
+ max_focal_length=10,
51
+ pose_encoding_type="absT_quaR_FoV",
52
+ build_intrinsics=True):
53
+
54
+ intrinsics = None
55
+
56
+ if pose_encoding_type == "absT_quaR_FoV":
57
+ T = pose_encoding[..., :3]
58
+ quat = pose_encoding[..., 3:7]
59
+ fov_h = pose_encoding[..., 7]
60
+ fov_w = pose_encoding[..., 8]
61
+
62
+ R = quat_to_mat(quat)
63
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
64
+
65
+ if build_intrinsics:
66
+ H, W = image_size_hw
67
+ fy = (H / 2.0) / torch.tan(fov_h / 2.0)
68
+ fx = (W / 2.0) / torch.tan(fov_w / 2.0)
69
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
70
+ intrinsics[..., 0, 0] = fx
71
+ intrinsics[..., 1, 1] = fy
72
+ intrinsics[..., 0, 2] = W / 2
73
+ intrinsics[..., 1, 2] = H / 2
74
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
75
+ elif pose_encoding_type == "absT_quaR_OneFLM1":
76
+ T = pose_encoding[..., :3]
77
+ quat = pose_encoding[..., 3:7]
78
+ focal_length_encoded = pose_encoding[..., 7]
79
+ focal_length = (focal_length_encoded + 1).clamp(min_focal_length, max_focal_length)
80
+ focal_length = focal_length * max(image_size_hw)
81
+ R = quat_to_mat(quat)
82
+ extrinsics = torch.cat([R, T[..., None]], dim=-1)
83
+
84
+ if build_intrinsics:
85
+ intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
86
+ intrinsics[..., 0, 0] = focal_length
87
+ intrinsics[..., 1, 1] = focal_length
88
+ intrinsics[..., 0, 2] = image_size_hw[1] / 2
89
+ intrinsics[..., 1, 2] = image_size_hw[0] / 2
90
+
91
+ # NOTE something is wrong here
92
+ intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
93
+ # TODO: verify the principal point; confirm whether image_size_hw is ordered (H, W) or (W, H)
94
+ else:
95
+ raise NotImplementedError
96
+
97
+ return extrinsics, intrinsics
98
+
99
+
100
+
101
+
102
+ def test_pose_encoding():
103
+ num_tests = 1000
104
+ batch_size = 4
105
+ num_cameras = 2
106
+ image_size_hw = (256, 512)
107
+ min_focal_length = 0.1
108
+ max_focal_length = 30
109
+ pose_encoding_type = "absT_quaR_OneFLM1"
110
+
111
+ for _ in range(num_tests):
112
+ # Generate random extrinsics and intrinsics
113
+ pose_encoding = torch.randn(batch_size, num_cameras, 8)
114
+
115
+ # converting forward and backward, and verifying the consistency
116
+ extrinsics, intrinsics = pose_encoding_to_extri_intri(pose_encoding, image_size_hw, min_focal_length, max_focal_length, pose_encoding_type)
117
+ pose_encoding_back = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw, pose_encoding_type, min_focal_length, max_focal_length)
118
+ extrinsics_forward, intrinsics_forward = pose_encoding_to_extri_intri(pose_encoding_back, image_size_hw, min_focal_length, max_focal_length, pose_encoding_type)
119
+ pose_encoding_forward = extri_intri_to_pose_encoding(extrinsics_forward, intrinsics_forward, image_size_hw, pose_encoding_type, min_focal_length, max_focal_length)
120
+ assert torch.allclose(pose_encoding_forward[..., :7], pose_encoding_back[..., :7], atol=1e-5), "Pose encoding does not match!"
121
+ print("All tests passed!")
122
+
123
+ if __name__ == "__main__":
124
+ test_pose_encoding()
125
+
126
+
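A round-trip sanity check for the two conversions above, assuming the default "absT_quaR_FoV" encoding (translation, scalar-last quaternion, vertical/horizontal FoV):

import torch
from vggt.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri

B, S, H, W = 1, 2, 256, 512
extrinsics = torch.eye(3, 4).expand(B, S, 3, 4).clone()          # identity rotation, zero translation
intrinsics = torch.tensor([[500.0, 0.0, W / 2],
                           [0.0, 500.0, H / 2],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3).clone()

enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))   # (B, S, 9)
extr_back, intr_back = pose_encoding_to_extri_intri(enc, image_size_hw=(H, W))
assert torch.allclose(extr_back, extrinsics, atol=1e-5)
assert torch.allclose(intr_back, intrinsics, atol=1e-2)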
vggt/utils/rotation.py ADDED
@@ -0,0 +1,200 @@
1
+ # Modified from PyTorch3D
2
+
3
+ import torch
4
+ import numpy as np
5
+ import torch.nn.functional as F
6
+ from scipy.spatial.transform import Rotation as R
7
+
8
+
9
+
10
+ def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
11
+ """
12
+ Quaternion order: XYZW (i.e., ijkr, scalar-last)
13
+
14
+ Convert rotations given as quaternions to rotation matrices.
15
+ Args:
16
+ quaternions: quaternions with real part last,
17
+ as tensor of shape (..., 4).
18
+
19
+ Returns:
20
+ Rotation matrices as tensor of shape (..., 3, 3).
21
+ """
22
+ i, j, k, r = torch.unbind(quaternions, -1)
23
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
24
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
25
+
26
+ o = torch.stack(
27
+ (
28
+ 1 - two_s * (j * j + k * k),
29
+ two_s * (i * j - k * r),
30
+ two_s * (i * k + j * r),
31
+ two_s * (i * j + k * r),
32
+ 1 - two_s * (i * i + k * k),
33
+ two_s * (j * k - i * r),
34
+ two_s * (i * k - j * r),
35
+ two_s * (j * k + i * r),
36
+ 1 - two_s * (i * i + j * j),
37
+ ),
38
+ -1,
39
+ )
40
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
41
+
42
+
43
+ def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
44
+ """
45
+ Convert rotations given as rotation matrices to quaternions.
46
+
47
+ Args:
48
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
49
+
50
+ Returns:
51
+ quaternions with real part last, as tensor of shape (..., 4).
52
+ Quaternion order: XYZW (i.e., ijkr, scalar-last)
53
+ """
54
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
55
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
56
+
57
+ batch_dim = matrix.shape[:-2]
58
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
59
+ matrix.reshape(batch_dim + (9,)), dim=-1
60
+ )
61
+
62
+ q_abs = _sqrt_positive_part(
63
+ torch.stack(
64
+ [
65
+ 1.0 + m00 + m11 + m22,
66
+ 1.0 + m00 - m11 - m22,
67
+ 1.0 - m00 + m11 - m22,
68
+ 1.0 - m00 - m11 + m22,
69
+ ],
70
+ dim=-1,
71
+ )
72
+ )
73
+
74
+ # we produce the desired quaternion multiplied by each of r, i, j, k
75
+ quat_by_rijk = torch.stack(
76
+ [
77
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
78
+ # `int`.
79
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
80
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
81
+ # `int`.
82
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
83
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
84
+ # `int`.
85
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
86
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
87
+ # `int`.
88
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
89
+ ],
90
+ dim=-2,
91
+ )
92
+
93
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
94
+ # the candidate won't be picked.
95
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
96
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
97
+
98
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
99
+ # forall i; we pick the best-conditioned one (with the largest denominator)
100
+ out = quat_candidates[
101
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
102
+ ].reshape(batch_dim + (4,))
103
+
104
+ # Convert from rijk to ijkr
105
+ out = out[..., [1, 2, 3, 0]]
106
+
107
+ out = standardize_quaternion(out)
108
+
109
+ return out
110
+
111
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
112
+ """
113
+ Returns torch.sqrt(torch.max(0, x))
114
+ but with a zero subgradient where x is 0.
115
+ """
116
+ ret = torch.zeros_like(x)
117
+ positive_mask = x > 0
118
+ if torch.is_grad_enabled():
119
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
120
+ else:
121
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
122
+ return ret
123
+
124
+
125
+ def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
126
+ """
127
+ Convert a unit quaternion to a standard form: one in which the real
128
+ part is non negative.
129
+
130
+ Args:
131
+ quaternions: Quaternions with real part last,
132
+ as tensor of shape (..., 4).
133
+
134
+ Returns:
135
+ Standardized quaternions as tensor of shape (..., 4).
136
+ """
137
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
138
+
139
+
140
+ def quat_to_mat_scipy(quaternions: np.ndarray) -> np.ndarray:
141
+ rotation = R.from_quat(quaternions)
142
+ return rotation.as_matrix()
143
+
144
+ def mat_to_quat_scipy(matrix: np.ndarray) -> np.ndarray:
145
+ rotation = R.from_matrix(matrix)
146
+ return rotation.as_quat()
147
+
148
+
149
+ if __name__ == "__main__":
150
+
151
+ num_tests = 10000 # Number of tests to run
152
+ tolerance = 1e-6 # Tolerance for floating point comparison
153
+
154
+ for _ in range(num_tests):
155
+ # Generate random quaternions
156
+ quaternions = torch.randn(1024, 4)
157
+ quaternions = quaternions / torch.norm(quaternions, dim=-1, keepdim=True) # Normalize to unit quaternions
158
+
159
+ # Convert quaternion to matrix using PyTorch
160
+ matrices_torch = quat_to_mat(quaternions)
161
+
162
+ # Convert matrices back to quaternions using PyTorch
163
+ quaternions_back = mat_to_quat(matrices_torch)
164
+
165
+ # Standardize quaternions to handle the case where quaternions = -quaternions_back
166
+ quaternions = standardize_quaternion(quaternions)
167
+ quaternions_back = standardize_quaternion(quaternions_back)
168
+
169
+ # Check if the original and converted quaternions match
170
+ if not torch.allclose(quaternions, quaternions_back, atol=tolerance):
171
+ print("Mismatch found!")
172
+ print("Original quaternions:", quaternions)
173
+ print("Converted quaternions:", quaternions_back)
174
+ max_error = torch.max(torch.abs(quaternions - quaternions_back))
175
+ print("Max error:", max_error)
176
+ else:
177
+ print("All tests passed successfully!")
178
+
179
+ # write code here
180
+
181
+ # quaternions = torch.randn(1024, 4) * 20
182
+ # # quaternions = quaternions / torch.norm(quaternions, dim=-1, keepdim=True) # Normalize to unit quaternions
183
+
184
+ # # Convert quaternion to matrix using PyTorch
185
+ # matrices_torch = quat_to_mat(quaternions).numpy()
186
+
187
+ # # Convert quaternion to matrix using SciPy
188
+ # matrices_scipy = quat_to_mat_scipy(quaternions.numpy())
189
+
190
+ # # Convert matrices back to quaternions using PyTorch
191
+ # quaternions_torch = mat_to_quat(torch.from_numpy(matrices_scipy)).numpy()
192
+
193
+ # # Convert matrices back to quaternions using SciPy
194
+ # quaternions_scipy = mat_to_quat_scipy(matrices_torch)
195
+
196
+
197
+ # reconvert_mat_diff = quat_to_mat_scipy(quaternions_torch) - quat_to_mat_scipy(quaternions_scipy)
198
+ # # Compare results
199
+ # print("Matrix conversion difference:", np.linalg.norm(matrices_torch - matrices_scipy))
200
+ # print("Quaternion conversion difference:", np.linalg.norm(reconvert_mat_diff))
viser_fn.py ADDED
@@ -0,0 +1,284 @@
+ """Visualization utilities for 3D reconstruction results using Viser.
+
+ Provides tools to visualize predicted camera poses, 3D point clouds, and confidence
+ thresholding through an interactive web interface.
+ """
+
+ import time
+ from pathlib import Path
+ from typing import List, Optional
+
+ import numpy as np
+ import tyro
+ from tqdm.auto import tqdm
+ import cv2
+ import viser
+ import viser.transforms as tf
+ import glob
+ import os
+ from scipy.spatial.transform import Rotation as R
+ # from camera import closed_form_inverse_se3
+ import torch
+ import threading
+
+
+ def viser_wrapper(
+     pred_dict: dict,
+     port: Optional[int] = None,
+     init_conf_threshold: float = 3.0,
+ ) -> None:
+     """Visualize predictions from pred_dict with an interactive viser server.
+
+     Args:
+         pred_dict: Dictionary containing predictions.
+         port: Optional port number for the viser server. If None, a random port will be used.
+         init_conf_threshold: Initial confidence threshold used to filter the point cloud.
+     """
+     print(f"Starting viser server on port {port}")  # Debug print
+
+     server = viser.ViserServer(host="0.0.0.0", port=port)
+     # server = viser.ViserServer(port=port)
+     server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")
+
+     # Unpack and preprocess inputs
+     images = pred_dict["images"]
+     world_points = pred_dict["pred_world_points"]
+     conf = pred_dict["pred_world_points_conf"]
+     extrinsics = pred_dict["last_pred_extrinsic"]
+
+     # Handle batch dimension if present
+     if len(images.shape) > 4:
+         images = images[0]
+         world_points = world_points[0]
+         conf = conf[0]
+         extrinsics = extrinsics[0]
+
+     colors = images.transpose(0, 2, 3, 1)  # Convert to (S, H, W, C)
+
+     # Reshape for visualization
+     S, H, W, _ = world_points.shape
+     colors = (colors.reshape(-1, 3) * 255).astype(np.uint8)  # Convert to 0-255 range
+     conf = conf.reshape(-1)
+     world_points = world_points.reshape(-1, 3)
+
+     # Calculate camera poses in world coordinates
+     cam_to_world = closed_form_inverse_se3(extrinsics)
+     extrinsics = cam_to_world[:, :3, :]
+
+     # Center scene for better visualization
+     scene_center = np.mean(world_points, axis=0)
+     world_points -= scene_center
+     extrinsics[..., -1] -= scene_center
+
+     # Set points3d as world_points
+     points = world_points
+
+     # Per-point frame indices, used to filter the point cloud by frame
+     frame_indices = np.arange(S)
+     frame_indices = frame_indices[:, None, None]  # Shape: (S, 1, 1)
+     frame_indices = np.tile(frame_indices, (1, H, W))  # Shape: (S, H, W)
+     frame_indices = frame_indices.reshape(-1)
+
+     ############################################################
+     ############################################################
+
+     gui_points_conf = server.gui.add_slider(
+         "Confidence Thres",
+         min=0.1,
+         max=20,
+         step=0.05,
+         initial_value=init_conf_threshold,
+     )
+
+     gui_point_size = server.gui.add_slider(
+         "Point size", min=0.00001, max=0.01, step=0.0001, initial_value=0.00001
+     )
+
+     # Change from "Frame Selector" to a more descriptive name
+     gui_frame_selector = server.gui.add_dropdown(
+         "Filter by Frame",  # More action-oriented name
+         options=["All"] + [str(i) for i in range(S)],
+         initial_value="All",
+     )
+
+     # Initial mask shows all points passing the confidence threshold
+     init_conf_mask = conf > init_conf_threshold
+     point_cloud = server.scene.add_point_cloud(
+         name="viser_pcd",
+         points=points[init_conf_mask],
+         colors=colors[init_conf_mask],
+         point_size=gui_point_size.value,
+         point_shape="circle",
+     )
+
+     frames: List[viser.FrameHandle] = []
+
+     def visualize_frames(extrinsics: np.ndarray, intrinsics: np.ndarray, images: np.ndarray) -> None:
+         """Send all camera frames and frustums to viser for visualization. This could be
+         optimized a ton!"""
+         extrinsics = np.copy(extrinsics)
+         # Remove existing image frames.
+         for frame in frames:
+             frame.remove()
+         frames.clear()
+
+         def attach_callback(
+             frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle
+         ) -> None:
+             @frustum.on_click
+             def _(_) -> None:
+                 for client in server.get_clients().values():
+                     client.camera.wxyz = frame.wxyz
+                     client.camera.position = frame.position
+
+         img_ids = sorted(range(S))
+         for img_id in tqdm(img_ids):
+             cam_to_world = extrinsics[img_id]
+
+             T_world_camera = tf.SE3.from_matrix(cam_to_world)
+
+             ratio = 1
+             frame = server.scene.add_frame(
+                 f"frame_{img_id}",
+                 wxyz=T_world_camera.rotation().wxyz,
+                 position=T_world_camera.translation(),
+                 axes_length=0.05 / ratio,
+                 axes_radius=0.002 / ratio,
+                 origin_radius=0.002 / ratio,
+             )
+
+             frames.append(frame)
+
+             img = images[img_id]
+             img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
+             H, W = img.shape[:2]
+             # fy = intrinsics[img_id, 1, 1] * H
+             fy = 1.1 * H
+             image = img
+             # image = image[::downsample_factor, ::downsample_factor]
+             frustum = server.scene.add_camera_frustum(
+                 f"frame_{img_id}/frustum",
+                 fov=2 * np.arctan2(H / 2, fy),
+                 aspect=W / H,
+                 scale=0.05 / ratio,
+                 image=image,
+                 line_width=1.0,
+                 # line_thickness=0.01,
+             )
+
+             attach_callback(frustum, frame)
+
+     @gui_points_conf.on_update
+     def _(_) -> None:
+         conf_mask = conf > gui_points_conf.value
+         frame_mask = np.ones_like(conf_mask)  # Default to all frames
+         if gui_frame_selector.value != "All":
+             selected_idx = int(gui_frame_selector.value)
+             frame_mask = (frame_indices == selected_idx)
+
+         combined_mask = conf_mask & frame_mask
+         point_cloud.points = points[combined_mask]
+         point_cloud.colors = colors[combined_mask]
+
+     @gui_point_size.on_update
+     def _(_) -> None:
+         point_cloud.point_size = gui_point_size.value
+
+     @gui_frame_selector.on_update
+     def _(_) -> None:
+         """Update points based on frame selection."""
+         conf_mask = conf > gui_points_conf.value
+
+         if gui_frame_selector.value == "All":
+             # Show all points passing the confidence threshold
+             point_cloud.points = points[conf_mask]
+             point_cloud.colors = colors[conf_mask]
+         else:
+             # Show only the selected frame's points
+             selected_idx = int(gui_frame_selector.value)
+             frame_mask = (frame_indices == selected_idx)
+             combined_mask = conf_mask & frame_mask
+             point_cloud.points = points[combined_mask]
+             point_cloud.colors = colors[combined_mask]
+
+             # Move camera to selected frame
+             # if 0 <= selected_idx < len(frames):
+             #     selected_frame = frames[selected_idx]
+             #     for client in server.get_clients().values():
+             #         client.camera.wxyz = selected_frame.wxyz
+             #         client.camera.position = selected_frame.position
+
+     # Initial visualization
+     visualize_frames(extrinsics, None, images)
+
+     # Start a server update loop in a background thread
+     def server_loop():
+         while True:
+             time.sleep(1e-3)  # Small sleep to prevent CPU hogging
+
+     thread = threading.Thread(target=server_loop, daemon=True)
+     thread.start()
+
+
+ def closed_form_inverse_se3(se3, R=None, T=None):
+     """
+     Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
+
+     If `R` and `T` are provided, they must correspond to the rotation and translation
+     components of `se3`. Otherwise, they will be extracted from `se3`.
+
+     Args:
+         se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
+         R (optional): Nx3x3 array or tensor of rotation matrices.
+         T (optional): Nx3x1 array or tensor of translation vectors.
+
+     Returns:
+         Inverted SE3 matrices with the same type and device as `se3`.
+
+     Shapes:
+         se3: (N, 4, 4)
+         R: (N, 3, 3)
+         T: (N, 3, 1)
+     """
+     # Check if se3 is a numpy array or a torch tensor
+     is_numpy = isinstance(se3, np.ndarray)
+
+     # Validate shapes
+     if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
+         raise ValueError(f"se3 must be of shape (N,4,4) or (N,3,4), got {se3.shape}.")
+
+     # Extract R and T if not provided
+     if R is None:
+         R = se3[:, :3, :3]  # (N,3,3)
+     if T is None:
+         T = se3[:, :3, 3:]  # (N,3,1)
+
+     # Transpose R
+     if is_numpy:
+         # Compute the transpose of the rotation for NumPy
+         R_transposed = np.transpose(R, (0, 2, 1))
+         # -R^T t for NumPy
+         top_right = -np.matmul(R_transposed, T)
+         inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
+     else:
+         R_transposed = R.transpose(1, 2)  # (N,3,3)
+         top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
+         inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
+         inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
+
+     inverted_matrix[:, :3, :3] = R_transposed
+     inverted_matrix[:, :3, 3:] = top_right
+
+     return inverted_matrix
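For completeness, a minimal sketch of how this module might be exercised. It assumes the file is importable as `viser_fn` and that viser is installed; the `pred_dict` keys are the ones `viser_wrapper` reads, while the shapes are plausible values inferred from how they are unpacked and are only illustrative:

    import numpy as np
    from scipy.spatial.transform import Rotation as R
    from viser_fn import viser_wrapper, closed_form_inverse_se3

    # Sanity check of the closed-form SE(3) inverse: inv(se3) @ se3 should be identity.
    se3 = np.tile(np.eye(4), (4, 1, 1))
    se3[:, :3, :3] = R.random(4).as_matrix()      # random rotations
    se3[:, :3, 3] = np.random.randn(4, 3)         # random translations
    assert np.allclose(np.matmul(closed_form_inverse_se3(se3), se3), np.eye(4), atol=1e-10)

    # Illustrative predictions: S frames of H x W images, per-pixel world points and
    # confidences, and camera-from-world extrinsics (3x4).
    S, H, W = 2, 64, 64
    pred_dict = {
        "images": np.random.rand(S, 3, H, W).astype(np.float32),
        "pred_world_points": np.random.randn(S, H, W, 3).astype(np.float32),
        "pred_world_points_conf": 10 * np.random.rand(S, H, W).astype(np.float32),
        "last_pred_extrinsic": np.tile(np.eye(4)[:3], (S, 1, 1)).astype(np.float32),
    }
    # viser_wrapper(pred_dict, port=8080, init_conf_threshold=3.0)  # starts the web viewer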