Commit febf487
Parent(s): 68f369a

init
Files changed:
- .gitattributes +1 -0
- .gitignore +174 -0
- README.md +5 -5
- app.py +272 -0
- clean_app.py +229 -0
- config/base.yaml +96 -0
- demo_hf.py +149 -0
- gradio_util.py +297 -0
- requirements.txt +28 -0
- vggt/heads/camera_head.py +220 -0
- vggt/heads/dpt_head.py +521 -0
- vggt/heads/head_act.py +97 -0
- vggt/heads/track_head.py +267 -0
- vggt/heads/utils.py +309 -0
- vggt/layers/__init__.py +11 -0
- vggt/layers/attention.py +116 -0
- vggt/layers/block.py +275 -0
- vggt/layers/dino_head.py +58 -0
- vggt/layers/drop_path.py +34 -0
- vggt/layers/layer_scale.py +27 -0
- vggt/layers/mlp.py +40 -0
- vggt/layers/patch_embed.py +88 -0
- vggt/layers/rope.py +160 -0
- vggt/layers/swiglu_ffn.py +72 -0
- vggt/layers/vision_transformer.py +408 -0
- vggt/models/aggregator.py +473 -0
- vggt/models/vggt.py +156 -0
- vggt/utils/pose_enc.py +126 -0
- vggt/utils/rotation.py +200 -0
- viser_fn.py +284 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/** filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,174 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ruff stuff:
.ruff_cache/

# PyPI configuration file
.pypirc
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: vggt
+emoji: 🏆
+colorFrom: indigo
+colorTo: indigo
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.17.1
 app_file: app.py
 pinned: false
 license: cc-by-nc-4.0
app.py
ADDED
@@ -0,0 +1,272 @@
import os
import cv2
import torch
import numpy as np
import gradio as gr
import spaces
import sys
import os
import socket
import webbrowser
sys.path.append('vggt/')
import shutil
from datetime import datetime
from demo_hf import demo_fn
from omegaconf import DictConfig, OmegaConf
import glob
import gc
import time
from viser_fn import viser_wrapper


def get_free_port():
    """Get a free port using socket."""
    # return 80
    # return 8080
    # return 10088 # for debugging
    # return 7860
    # return 7888
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        port = s.getsockname()[1]
    return port




@spaces.GPU(duration=240)
def vggt_demo(
    input_video,
    input_image,
):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    debug = False

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)

    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    if input_video is not None:
        if not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        input_image = sorted(input_image)
        # recon_num = len(input_image)

        # Copy files to the new directory
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        vs = cv2.VideoCapture(input_video)

        fps = vs.get(cv2.CAP_PROP_FPS)

        frame_rate = 1
        frame_interval = int(fps * frame_rate)

        video_frame_num = 0
        count = 0

        while True:
            (gotit, frame) = vs.read()
            count += 1

            if not gotit:
                break

            if count % frame_interval == 0:
                cv2.imwrite(target_dir_images + "/" + f"{video_frame_num:06}.png", frame)
                video_frame_num += 1

        # recon_num = video_frame_num
        # if recon_num < 3:
        #     return None, "Please input at least three frames"
    else:
        return None, "Uploading not finished or Incorrect input format"

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    predictions = demo_fn(cfg)

    # Get a free port for viser
    viser_port = get_free_port()

    # Start viser visualization in a separate thread/process
    viser_wrapper(predictions, port=viser_port)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")

    # Return None for the 3D model (since we're using viser) and the viser URL
    # viser_url = f"Viser visualization is ready at: http://localhost:{viser_port}"
    # print(viser_url)  # Debug print
    return None, viser_port




statue_video = "examples/videos/statue_video.mp4"

apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"

horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"

flower_video = "examples/videos/llff_flower_video.mp4"

fern_video = "examples/videos/llff_fern_video.mp4"

drums_video = "examples/videos/drums_video.mp4"

kitchen_video = "examples/videos/kitchen_video.mp4"

###########################################################################################
apple_images = glob.glob(f'examples/apple/images/*')
bonsai_images = glob.glob(f'examples/bonsai/images/*')
cake_images = glob.glob(f'examples/cake/images/*')
british_museum_images = glob.glob(f'examples/british_museum/images/*')
face_images = glob.glob(f'examples/in2n_face/images/*')
counter_images = glob.glob(f'examples/in2n_counter/images/*')

horns_images = glob.glob(f'examples/llff_horns/images/*')

person_images = glob.glob(f'examples/in2n_person/images/*')
flower_images = glob.glob(f'examples/llff_flower/images/*')

fern_images = glob.glob(f'examples/llff_fern/images/*')
statue_images = glob.glob(f'examples/statue/images/*')

drums_images = glob.glob(f'examples/drums/images/*')
kitchen_images = glob.glob(f'examples/kitchen/images/*')


###########################################################################################


with gr.Blocks() as demo:

    gr.Markdown("""
    # 🏛️ VGGT: Visual Geometry Grounded Transformer

    <div style="font-size: 16px; line-height: 1.2;">
    Alpha version (testing).
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

        with gr.Column(scale=3):
            viser_output = gr.HTML(
                label="Viser Visualization",
                value='''<div style="height: 520px; border: 1px solid #e0e0e0;
                            border-radius: 4px; padding: 16px;
                            display: flex; align-items: center;
                            justify-content: center">
                            3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)
                        </div>'''
            )

    log_output = gr.Textbox(label="Log")

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        clear_btn = gr.ClearButton([input_video, input_images, viser_output, log_output], scale=1)  # Modified viser_output

    examples = [
        [flower_video, flower_images],
        [kitchen_video, kitchen_images],
        # [person_video, person_images],
        # [statue_video, statue_images],
        # [drums_video, drums_images],
        [counter_video, counter_images],
        [fern_video, fern_images],
        [horns_video, horns_images],
        # [apple_video, apple_images],
        # [bonsai_video, bonsai_images],
    ]

    def process_example(video, images):
        """Wrapper function to ensure outputs are properly captured"""
        model_output, log = vggt_demo(video, images)

        # viser_wrapper(predictions, port=log)
        # Get the hostname - use the actual hostname or IP where the server is running
        # hostname = socket.gethostname()

        # Extract port from log
        port = log

        # Create the viser URL using the hostname
        # viser_url = f"http://{hostname}:{port}"

        viser_url = f"http://localhost:{log}"
        print(f"Viser URL: {viser_url}")

        # Create the iframe HTML code. Set width and height appropriately.
        iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px"></iframe>'

        # Return the iframe code to update the gr.HTML component
        return iframe_code, f"Visualization running at {viser_url}"

    # TODO: move the selection of port outside of the demo function
    # so that we can cache examples

    gr.Examples(examples=examples,
                inputs=[input_video, input_images],
                outputs=[viser_output, log_output],  # Output to viser_output
                fn=process_example,  # Use our wrapper function
                cache_examples=False,
                examples_per_page=50,
    )

    submit_btn.click(
        process_example,  # Use the same wrapper function
        [input_video, input_images],
        [viser_output, log_output],  # Output to viser_output
        # concurrency_limit=1
    )

# demo.launch(debug=True, share=True)
# demo.launch(server_name="0.0.0.0", server_port=8082, debug=True, share=False)
# demo.queue(max_size=20).launch(show_error=True, share=True)
demo.queue(max_size=20).launch(show_error=True)  # , share=True, server_port=7888, server_name="0.0.0.0")
# demo.queue(max_size=20, concurrency_count=1).launch(debug=True, share=True)
########################################################################################################################
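A note on the hand-off above: vggt_demo returns only the port of the freshly started viser server, and process_example wraps that port in an iframe fed to the gr.HTML component. A minimal, self-contained sketch of the same free-port-plus-iframe pattern, using Python's built-in HTTP server as a stand-in for viser (the server and handler here are placeholders, not part of this commit):

import http.server
import socket
import threading

def get_free_port() -> int:
    # Same trick as app.py: bind to port 0 and let the OS pick a free port.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

port = get_free_port()
server = http.server.HTTPServer(("", port), http.server.SimpleHTTPRequestHandler)
threading.Thread(target=server.serve_forever, daemon=True).start()

# The Gradio HTML component is then pointed at that port via an iframe,
# exactly as process_example does with the viser port.
iframe_code = f'<iframe src="http://localhost:{port}" width="100%" height="520px"></iframe>'
print(iframe_code)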
clean_app.py
ADDED
@@ -0,0 +1,229 @@
import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import os
import socket
import webbrowser
sys.path.append('vggt/')
import shutil
from datetime import datetime
from demo_hf import demo_fn
from omegaconf import DictConfig, OmegaConf
import glob
import gc
import time
from viser_fn import viser_wrapper


def get_free_port():
    """Get a free port using socket."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(('', 0))
        port = s.getsockname()[1]
    return port

def vggt_demo(
    input_video,
    input_image,
):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    debug = False

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    target_dir = f"input_images_{timestamp}"
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)

    os.makedirs(target_dir)
    target_dir_images = target_dir + "/images"
    os.makedirs(target_dir_images)

    if input_video is not None:
        if not isinstance(input_video, str):
            input_video = input_video["video"]["path"]

    cfg_file = "config/base.yaml"
    cfg = OmegaConf.load(cfg_file)

    if input_image is not None:
        input_image = sorted(input_image)
        # recon_num = len(input_image)

        # Copy files to the new directory
        for file_name in input_image:
            shutil.copy(file_name, target_dir_images)
    elif input_video is not None:
        vs = cv2.VideoCapture(input_video)

        fps = vs.get(cv2.CAP_PROP_FPS)

        frame_rate = 1
        frame_interval = int(fps * frame_rate)

        video_frame_num = 0
        count = 0

        while True:
            (gotit, frame) = vs.read()
            count += 1

            if not gotit:
                break

            if count % frame_interval == 0:
                cv2.imwrite(target_dir_images + "/" + f"{video_frame_num:06}.png", frame)
                video_frame_num += 1
    else:
        return None, "Uploading not finished or Incorrect input format"

    print(f"Files have been copied to {target_dir_images}")
    cfg.SCENE_DIR = target_dir

    predictions = demo_fn(cfg)

    # Get a free port for viser
    viser_port = get_free_port()

    # Start viser visualization in a separate thread/process
    viser_wrapper(predictions, port=viser_port)

    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    print(input_image)
    print(input_video)
    end_time = time.time()
    execution_time = end_time - start_time
    print(f"Execution time: {execution_time} seconds")
    return None, viser_port




statue_video = "examples/videos/statue_video.mp4"

apple_video = "examples/videos/apple_video.mp4"
british_museum_video = "examples/videos/british_museum_video.mp4"
cake_video = "examples/videos/cake_video.mp4"
bonsai_video = "examples/videos/bonsai_video.mp4"
face_video = "examples/videos/in2n_face_video.mp4"
counter_video = "examples/videos/in2n_counter_video.mp4"

horns_video = "examples/videos/llff_horns_video.mp4"
person_video = "examples/videos/in2n_person_video.mp4"

flower_video = "examples/videos/llff_flower_video.mp4"

fern_video = "examples/videos/llff_fern_video.mp4"

drums_video = "examples/videos/drums_video.mp4"

kitchen_video = "examples/videos/kitchen_video.mp4"

###########################################################################################
apple_images = glob.glob(f'examples/apple/images/*')
bonsai_images = glob.glob(f'examples/bonsai/images/*')
cake_images = glob.glob(f'examples/cake/images/*')
british_museum_images = glob.glob(f'examples/british_museum/images/*')
face_images = glob.glob(f'examples/in2n_face/images/*')
counter_images = glob.glob(f'examples/in2n_counter/images/*')

horns_images = glob.glob(f'examples/llff_horns/images/*')

person_images = glob.glob(f'examples/in2n_person/images/*')
flower_images = glob.glob(f'examples/llff_flower/images/*')

fern_images = glob.glob(f'examples/llff_fern/images/*')
statue_images = glob.glob(f'examples/statue/images/*')

drums_images = glob.glob(f'examples/drums/images/*')
kitchen_images = glob.glob(f'examples/kitchen/images/*')


###########################################################################################


with gr.Blocks() as demo:

    gr.Markdown("""
    # 🏛️ VGGT: Visual Geometry Grounded Transformer

    <div style="font-size: 16px; line-height: 1.2;">
    Alpha version (testing).
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

        with gr.Column(scale=3):
            viser_output = gr.HTML(
                label="Viser Visualization",
                value='''<div style="height: 520px; border: 1px solid #e0e0e0;
                            border-radius: 4px; padding: 16px;
                            display: flex; align-items: center;
                            justify-content: center">
                            3D Reconstruction (Point Cloud and Camera Poses; Zoom in to see details)
                        </div>'''
            )

    log_output = gr.Textbox(label="Log")

    with gr.Row():
        submit_btn = gr.Button("Reconstruct", scale=1)
        clear_btn = gr.ClearButton([input_video, input_images, viser_output, log_output], scale=1)  # Modified viser_output

    examples = [
        [flower_video, flower_images],
        [kitchen_video, kitchen_images],
        [counter_video, counter_images],
        [fern_video, fern_images],
        [horns_video, horns_images],
    ]

    def process_example(video, images):
        """Wrapper function to ensure outputs are properly captured"""
        model_output, log = vggt_demo(video, images)

        viser_url = f"http://localhost:{log}"
        print(f"Viser URL: {viser_url}")

        # Create the iframe HTML code. Set width and height appropriately.
        iframe_code = f'<iframe src="{viser_url}" width="100%" height="520px"></iframe>'

        return iframe_code, f"Visualization running at {viser_url}"

    gr.Examples(examples=examples,
                inputs=[input_video, input_images],
                outputs=[viser_output, log_output],  # Output to viser_output
                fn=process_example,  # Use our wrapper function
                cache_examples=False,
                examples_per_page=50,
    )


    submit_btn.click(
        process_example,  # Use the same wrapper function
        [input_video, input_images],
        [viser_output, log_output],  # Output to viser_output
        concurrency_limit=1
    )
demo.queue(max_size=20).launch(show_error=True, share=True, server_port=7888, server_name="0.0.0.0")
config/base.yaml
ADDED
@@ -0,0 +1,96 @@
SCENE_DIR: examples/apple/
# examples/llff_horns_single/
# apple
# cake

_target_: vggt.models.vggt.VGGT #off3d.models.vggt.vggt.VGGT

num_register_tokens: 4 # 0 for no register tokens
ffn_layer: "mlp"
qk_norm: False # NOTE: is this correct?
patch_size: 14
init_values: 0.01

AGGREGATOR:
  _target_: vggt.models.aggregator.Aggregator
  patch_embed_by_conv: False
  image_size: 518
  use_checkpoint: True
  use_reentrant: False
  decoder_load_dino: False
  backbone_qk_norm: False
  aa_block_kwargs:
    dim: 1024
    num_heads: 16
    mlp_ratio: 4
    qkv_bias: True
    proj_bias: True
    ffn_bias: True
    drop: 0.0
    attn_drop: 0.0
    init_values: 0.01
    drop_path: 0.0
    fused_attn: True
    qk_norm: True
    rope_freq: 100


CameraHead:
  _target_: vggt.heads.camera_head.CameraHead #off3d.models.vggt.camera_head.CameraHead
  pose_encoding_type: "absT_quaR_FoV"
  new_trunk: True
  trunk_depth: 4
  # proj_dim: 768
  qk_norm: True
  init_values: 0.01
  act_dict:
    trans_act: "linear"
    quat_act: "linear"
    fl_act: "linear"
  loss_kwargs:
    loss_type: "l1"
    gamma: 0.6


PointHead:
  _target_: vggt.heads.dpt_head.DPTHead #off3d.models.vggt.dpt_head.DPTHead
  # _target_: off3d.models.vggt.linear_head.LinearHead
  dim_in: 2048
  shallow_conv: False
  normalize_act: "inv_log"
  pos_embed: True
  loss_kwargs:
    gradient_loss: "normal"
    # gradient_loss: "grad"
    normalize_pred: False
    valid_range: 0.98
    gamma: 1.0
    camera_centric_reg: -1.0
    all_mean: True

DepthHead: null
#   _target_: vggt.heads.dpt_head.DPTHead #off3d.models.vggt.dpt_head.DPTHead
#   # _target_: off3d.models.vggt.linear_head.LinearHead
#   dim_in: 2048
#   patch_size: ${patch_size}
#   output_dim: 2
#   normalize_act: "exp" # or just relu?
#   normalize_act_conf: "expp1"
#   pos_embed: True
#   loss_kwargs:
#     loss_type: "conf"
#     predict_disparity: False # or True
#     gradient_loss: "grad"
#     valid_range: 0.98
#     gamma: 1.0
#     all_mean: True

MatchHead: null
TrackHead: null



hydra:
  output_subdir: NULL
  run:
    dir: .
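A short sketch of how a config like this is consumed; it mirrors what demo_hf.py below does, and assumes the file sits at config/base.yaml with the vggt package importable:

from omegaconf import OmegaConf
from hydra.utils import instantiate

cfg = OmegaConf.load("config/base.yaml")
print(cfg.SCENE_DIR)                         # examples/apple/
print(cfg.CameraHead.pose_encoding_type)     # absT_quaR_FoV

# _target_ keys let Hydra build the configured classes; _recursive_=False
# defers construction of the sub-heads to the VGGT model itself.
model = instantiate(cfg, _recursive_=False)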
demo_hf.py
ADDED
@@ -0,0 +1,149 @@
import hydra
import torch
import os
from hydra.utils import instantiate
from omegaconf import DictConfig
from PIL import Image
from torchvision import transforms as TF
import glob
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from viser_fn import viser_wrapper


# @hydra.main(config_path="config", config_name="base")
def demo_fn(cfg: DictConfig) -> None:
    print(cfg)
    model = instantiate(cfg, _recursive_=False)

    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")

    device = "cuda"
    model = model.to(device)

    _VGGT_URL = "https://huggingface.co/facebook/vggt_alpha/resolve/main/vggt_alpha_v0.pt"

    # Reload model
    pretrain_model = torch.hub.load_state_dict_from_url(_VGGT_URL)

    if "model" in pretrain_model:
        model_dict = pretrain_model["model"]
        model.load_state_dict(model_dict, strict=False)
    else:
        model.load_state_dict(pretrain_model, strict=True)

    # batch = torch.load("/fsx-repligen/jianyuan/cvpr2025_ckpts/batch.pth")
    # y_hat_raw = torch.load("/fsx-repligen/jianyuan/cvpr2025_ckpts/y_hat.pth")

    image_list = glob.glob(os.path.join(cfg.SCENE_DIR, "images", "*"))
    image_list = sorted(image_list)
    images = load_and_preprocess_images(image_list)
    images = images[None].to(device)

    batch = {"images": images}

    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=torch.float16):
            y_hat = model(batch)

    last_pred_pose_enc = y_hat["pred_extrinsic_list"][-1]
    pose_encoding_type = cfg.CameraHead.pose_encoding_type

    last_pred_extrinsic, _ = pose_encoding_to_extri_intri(last_pred_pose_enc.detach(), None, pose_encoding_type=pose_encoding_type, build_intrinsics=False)

    y_hat["last_pred_extrinsic"] = last_pred_extrinsic

    for key in y_hat.keys():
        if isinstance(y_hat[key], torch.Tensor):
            y_hat[key] = y_hat[key].cpu().numpy()

    return y_hat



def load_and_preprocess_images(image_path_list):
    # Check for empty list
    if len(image_path_list) == 0:
        raise ValueError("At least 1 image is required")

    # 1. load images as RGB
    # 2. resize images to (518, X, 3), where X is the resized width and X should be divisible by 14
    # 3. normalize images to (0, 1)
    # 4. concatenate images to (N, 3, 518, X), where N is the number of images
    images = []
    shapes = set()
    to_tensor = TF.ToTensor()

    # First process all images and collect their shapes
    for image_path in image_path_list:
        img = Image.open(image_path).convert("RGB")
        width, height = img.size
        new_width = 518

        # Calculate height maintaining aspect ratio, divisible by 14
        new_height = round(height * (new_width / width) / 14) * 14

        # Resize with new dimensions (width, height)

        img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
        img = to_tensor(img)  # Convert to tensor (0, 1)

        # Center crop height if it's larger than 518

        if new_height > 518:
            start_y = (new_height - 518) // 2
            img = img[:, start_y:start_y + 518, :]

        shapes.add((img.shape[1], img.shape[2]))
        images.append(img)

    # Check if we have different shapes
    if len(shapes) > 1:
        print(f"Warning: Found images with different shapes: {shapes}")
        # Find maximum dimensions
        max_height = max(shape[0] for shape in shapes)
        max_width = max(shape[1] for shape in shapes)

        # Pad images if necessary
        padded_images = []
        for img in images:
            h_padding = max_height - img.shape[1]
            w_padding = max_width - img.shape[2]

            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left

                img = torch.nn.functional.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode='constant',
                    value=1.0
                )
            padded_images.append(img)
        images = padded_images


    images = torch.stack(images)  # concatenate images

    # Ensure correct shape when single image
    if len(image_path_list) == 1:
        # Verify shape is (1, C, H, W)
        if images.dim() == 3:
            images = images.unsqueeze(0)

    return images


# if __name__ == "__main__":
#     y_hat = demo_fn()
#     # viser_wrapper(y_hat, port=8080)

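For reference, a minimal usage sketch of load_and_preprocess_images; it assumes a local folder of example images exists, and the path is just the default SCENE_DIR from config/base.yaml:

import glob
from demo_hf import load_and_preprocess_images

paths = sorted(glob.glob("examples/apple/images/*"))
images = load_and_preprocess_images(paths)
# Expected: a float tensor of shape (N, 3, H, 518), values in [0, 1],
# with H a multiple of 14 (mixed-size inputs are padded to a common size).
print(images.shape, images.min().item(), images.max().item())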
gradio_util.py
ADDED
@@ -0,0 +1,297 @@
try:
    import os

    import trimesh
    import open3d as o3d

    import gradio as gr
    import numpy as np
    import matplotlib
    from scipy.spatial.transform import Rotation

    print("Successfully imported the packages for Gradio visualization")
except:
    print(
        f"Failed to import packages for Gradio visualization. Please disable gradio visualization"
    )


def visualize_by_gradio(glbfile):
    """
    Set up and launch a Gradio interface to visualize a GLB file.

    Args:
        glbfile (str): Path to the GLB file to be visualized.
    """

    def load_glb_file(glb_path):
        # Check if the file exists and return the path or error message
        if os.path.exists(glb_path):
            return glb_path, "3D Model Loaded Successfully"
        else:
            return None, "File not found"

    # Load the GLB file initially to check if it's valid
    initial_model, log_message = load_glb_file(glbfile)

    # Create the Gradio interface
    with gr.Blocks() as demo:
        gr.Markdown("# GLB File Viewer")

        # 3D Model viewer component
        model_viewer = gr.Model3D(
            label="3D Model Viewer", height=600, value=initial_model
        )

        # Textbox for log output
        log_output = gr.Textbox(label="Log", lines=2, value=log_message)

    # Launch the Gradio interface
    demo.launch(share=True)


def vggsfm_predictions_to_glb(predictions) -> trimesh.Scene:
    """
    Converts VGG SFM predictions to a 3D scene represented as a GLB.

    Args:
        predictions (dict): A dictionary containing model predictions.

    Returns:
        trimesh.Scene: A 3D scene object.
    """
    # Convert predictions to numpy arrays
    vertices_3d = predictions["points3D"].cpu().numpy()
    colors_rgb = (predictions["points3D_rgb"].cpu().numpy() * 255).astype(
        np.uint8
    )

    if True:
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(vertices_3d)
        pcd.colors = o3d.utility.Vector3dVector(colors_rgb)

        cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=1.0)
        filtered_pcd = pcd.select_by_index(ind)

        print(f"Filter out {len(vertices_3d) - len(filtered_pcd.points)} 3D points")
        vertices_3d = np.asarray(filtered_pcd.points)
        colors_rgb = np.asarray(filtered_pcd.colors).astype(np.uint8)


    camera_matrices = predictions["extrinsics_opencv"].cpu().numpy()

    # Calculate the 5th and 95th percentiles along each axis
    lower_percentile = np.percentile(vertices_3d, 5, axis=0)
    upper_percentile = np.percentile(vertices_3d, 95, axis=0)

    # Calculate the diagonal length of the percentile bounding box
    scene_scale = np.linalg.norm(upper_percentile - lower_percentile)

    colormap = matplotlib.colormaps.get_cmap("gist_rainbow")

    # Initialize a 3D scene
    scene_3d = trimesh.Scene()

    # Add point cloud data to the scene
    point_cloud_data = trimesh.PointCloud(
        vertices=vertices_3d, colors=colors_rgb
    )

    scene_3d.add_geometry(point_cloud_data)

    # Prepare 4x4 matrices for camera extrinsics
    num_cameras = len(camera_matrices)
    extrinsics_matrices = np.zeros((num_cameras, 4, 4))
    extrinsics_matrices[:, :3, :4] = camera_matrices
    extrinsics_matrices[:, 3, 3] = 1

    # Add camera models to the scene
    for i in range(num_cameras):
        world_to_camera = extrinsics_matrices[i]
        camera_to_world = np.linalg.inv(world_to_camera)
        rgba_color = colormap(i / num_cameras)
        current_color = tuple(int(255 * x) for x in rgba_color[:3])

        integrate_camera_into_scene(
            scene_3d, camera_to_world, current_color, scene_scale
        )

    # Align scene to the observation of the first camera
    scene_3d = apply_scene_alignment(scene_3d, extrinsics_matrices)

    return scene_3d


def apply_scene_alignment(
    scene_3d: trimesh.Scene, extrinsics_matrices: np.ndarray
) -> trimesh.Scene:
    """
    Aligns the 3D scene based on the extrinsics of the first camera.

    Args:
        scene_3d (trimesh.Scene): The 3D scene to be aligned.
        extrinsics_matrices (np.ndarray): Camera extrinsic matrices.

    Returns:
        trimesh.Scene: Aligned 3D scene.
    """
    # Set transformations for scene alignment
    opengl_conversion_matrix = get_opengl_conversion_matrix()

    # Rotation matrix for alignment (180 degrees around the y-axis)
    align_rotation = np.eye(4)
    align_rotation[:3, :3] = Rotation.from_euler(
        "y", 180, degrees=True
    ).as_matrix()

    # Apply transformation
    initial_transformation = (
        np.linalg.inv(extrinsics_matrices[0])
        @ opengl_conversion_matrix
        @ align_rotation
    )
    scene_3d.apply_transform(initial_transformation)
    return scene_3d


def integrate_camera_into_scene(
    scene: trimesh.Scene,
    transform: np.ndarray,
    face_colors: tuple,
    scene_scale: float,
):
    """
    Integrates a fake camera mesh into the 3D scene.

    Args:
        scene (trimesh.Scene): The 3D scene to add the camera model.
        transform (np.ndarray): Transformation matrix for camera positioning.
        face_colors (tuple): Color of the camera face.
        scene_scale (float): Scale of the scene.
    """

    cam_width = scene_scale * 0.05
    cam_height = scene_scale * 0.1

    # Create cone shape for camera
    rot_45_degree = np.eye(4)
    rot_45_degree[:3, :3] = Rotation.from_euler(
        "z", 45, degrees=True
    ).as_matrix()
    rot_45_degree[2, 3] = -cam_height

    opengl_transform = get_opengl_conversion_matrix()
    # Combine transformations
    complete_transform = transform @ opengl_transform @ rot_45_degree
    camera_cone_shape = trimesh.creation.cone(cam_width, cam_height, sections=4)

    # Generate mesh for the camera
    slight_rotation = np.eye(4)
    slight_rotation[:3, :3] = Rotation.from_euler(
        "z", 2, degrees=True
    ).as_matrix()

    vertices_combined = np.concatenate(
        [
            camera_cone_shape.vertices,
            0.95 * camera_cone_shape.vertices,
            transform_points(slight_rotation, camera_cone_shape.vertices),
        ]
    )
    vertices_transformed = transform_points(
        complete_transform, vertices_combined
    )

    mesh_faces = compute_camera_faces(camera_cone_shape)

    # Add the camera mesh to the scene
    camera_mesh = trimesh.Trimesh(
        vertices=vertices_transformed, faces=mesh_faces
    )
    camera_mesh.visual.face_colors[:, :3] = face_colors
    scene.add_geometry(camera_mesh)


def compute_camera_faces(cone_shape: trimesh.Trimesh) -> np.ndarray:
    """
    Computes the faces for the camera mesh.

    Args:
        cone_shape (trimesh.Trimesh): The shape of the camera cone.

    Returns:
        np.ndarray: Array of faces for the camera mesh.
    """
    # Create pseudo cameras
    faces_list = []
    num_vertices_cone = len(cone_shape.vertices)

    for face in cone_shape.faces:
        if 0 in face:
            continue
        v1, v2, v3 = face
        v1_offset, v2_offset, v3_offset = face + num_vertices_cone
        v1_offset_2, v2_offset_2, v3_offset_2 = face + 2 * num_vertices_cone

        faces_list.extend(
            [
                (v1, v2, v2_offset),
                (v1, v1_offset, v3),
                (v3_offset, v2, v3),
                (v1, v2, v2_offset_2),
                (v1, v1_offset_2, v3),
                (v3_offset_2, v2, v3),
            ]
        )

    faces_list += [(v3, v2, v1) for v1, v2, v3 in faces_list]
    return np.array(faces_list)


def transform_points(
    transformation: np.ndarray, points: np.ndarray, dim: int = None
) -> np.ndarray:
    """
    Applies a 4x4 transformation to a set of points.

    Args:
        transformation (np.ndarray): Transformation matrix.
        points (np.ndarray): Points to be transformed.
        dim (int, optional): Dimension for reshaping the result.

    Returns:
        np.ndarray: Transformed points.
    """
    points = np.asarray(points)
    initial_shape = points.shape[:-1]
    dim = dim or points.shape[-1]

    # Apply transformation
    transformation = transformation.swapaxes(
        -1, -2
    )  # Transpose the transformation matrix
    points = points @ transformation[..., :-1, :] + transformation[..., -1:, :]

    # Reshape the result
    result = points[..., :dim].reshape(*initial_shape, dim)
    return result


def get_opengl_conversion_matrix() -> np.ndarray:
    """
    Constructs and returns the OpenGL conversion matrix.

    Returns:
        numpy.ndarray: A 4x4 OpenGL conversion matrix.
    """
    # Create an identity matrix
    matrix = np.identity(4)

    # Flip the y and z axes
    matrix[1, 1] = -1
    matrix[2, 2] = -1

    return matrix
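The transform_points helper above uses a row-vector convention: points are right-multiplied by the transposed rotation block and the translation row is added. A small self-contained numeric check of that convention (written inline rather than importing gradio_util, since importing the module needs trimesh and open3d installed):

import numpy as np

T = np.eye(4)
T[:3, 3] = [1.0, 2.0, 3.0]          # pure translation by (1, 2, 3)
pts = np.random.rand(5, 3)

Tt = T.swapaxes(-1, -2)             # same transpose step as transform_points
out = (pts @ Tt[:-1, :] + Tt[-1:, :])[:, :3]
assert np.allclose(out, pts + np.array([1.0, 2.0, 3.0]))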
requirements.txt
ADDED
@@ -0,0 +1,28 @@
torch==2.4.0
torchvision==0.19.0
hydra-core==1.3.2
scipy
omegaconf
opencv-python
einops
numpy==1.26.3
viser




# accelerate==0.24.0
# git+https://github.com/cvg/LightGlue.git#egg=LightGlue
# pycolmap==0.6.1
# https://huggingface.co/facebook/VGGSfM/resolve/main/poselib-2.0.2-cp310-cp310-linux_x86_64.whl
# trimesh
# open3d

# hydra-core==1.3.2
# scipy
# omegaconf
# opencv-python
# einops
# numpy==1.26.3
# trimesh
# open3d
vggt/heads/camera_head.py
ADDED
@@ -0,0 +1,220 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import math
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from hydra.utils import instantiate
|
14 |
+
|
15 |
+
from vggt.layers.block import Block
|
16 |
+
|
17 |
+
from vggt.layers import Mlp
|
18 |
+
from vggt.heads.utils import PoseEmbedding
|
19 |
+
from vggt.heads.head_act import activate_pose
|
20 |
+
|
21 |
+
def modulate(x, shift, scale):
|
22 |
+
# modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
|
23 |
+
return x * (1 + scale) + shift
|
24 |
+
|
25 |
+
|
26 |
+
|
27 |
+
class CameraHead(nn.Module):
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
dim_in=2048,
|
31 |
+
patch_size=14,
|
32 |
+
qk_norm=False,
|
33 |
+
trunk_depth=4,
|
34 |
+
new_trunk=True,
|
35 |
+
update_new_trunk_tokens=False,
|
36 |
+
pose_encoding_type="absT_quaR_FoV",
|
37 |
+
proj_dim=-1,
|
38 |
+
num_heads=16,
|
39 |
+
mlp_ratio=4,
|
40 |
+
init_values=None,
|
41 |
+
act_dict=None,
|
42 |
+
**kwargs,
|
43 |
+
):
|
44 |
+
super().__init__()
|
45 |
+
|
46 |
+
# Three types:
|
47 |
+
# 1. Linear projection
|
48 |
+
# 2. New trunk
|
49 |
+
# 3. Old trunk
|
50 |
+
|
51 |
+
self.new_trunk = new_trunk
|
52 |
+
if pose_encoding_type=="absT_quaR_FoV":
|
53 |
+
self.target_dim = 9
|
54 |
+
elif pose_encoding_type=="absT_quaR_OneFLM1":
|
55 |
+
self.target_dim = 8
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Unsupported pose encoding type: {pose_encoding_type}")
|
58 |
+
|
59 |
+
self.update_new_trunk_tokens = update_new_trunk_tokens
|
60 |
+
self.act_dict = act_dict
|
61 |
+
self.trunk_depth = trunk_depth
|
62 |
+
|
63 |
+
self.token_norm = nn.LayerNorm(dim_in)
|
64 |
+
|
65 |
+
if proj_dim > 0:
|
66 |
+
self.proj = nn.Linear(dim_in, proj_dim)
|
67 |
+
dim_in = proj_dim
|
68 |
+
else:
|
69 |
+
self.proj = nn.Identity()
|
70 |
+
|
71 |
+
if self.trunk_depth <0:
|
72 |
+
self.pose_branch = nn.Linear(dim_in, self.target_dim)
|
73 |
+
else:
|
74 |
+
self.trunk = nn.Sequential(
|
75 |
+
*[
|
76 |
+
Block(
|
77 |
+
dim=dim_in,
|
78 |
+
num_heads=num_heads,
|
79 |
+
mlp_ratio=mlp_ratio,
|
80 |
+
qk_norm=qk_norm,
|
81 |
+
init_values=init_values,
|
82 |
+
)
|
83 |
+
for _ in range(trunk_depth)
|
84 |
+
]
|
85 |
+
)
|
86 |
+
self.trunk_norm = nn.LayerNorm(dim_in)
|
87 |
+
|
88 |
+
if self.new_trunk:
|
89 |
+
# TODO: self.empty_pose_tokens -> BxSxC
|
90 |
+
self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
|
91 |
+
self.embed_pose = nn.Linear(self.target_dim, dim_in)
|
92 |
+
|
93 |
+
self.poseLN_modulation = nn.Sequential(
|
94 |
+
nn.SiLU(),
|
95 |
+
nn.Linear(dim_in, 3 * dim_in, bias=True)
|
96 |
+
)
|
97 |
+
|
98 |
+
self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
|
99 |
+
self.pose_branch = Mlp(
|
100 |
+
in_features=dim_in,
|
101 |
+
hidden_features=dim_in // 2,
|
102 |
+
out_features=self.target_dim,
|
103 |
+
drop=0,
|
104 |
+
)
|
105 |
+
else:
|
106 |
+
self.ffeat_norm = nn.LayerNorm(dim_in)
|
107 |
+
self.pose_branch = Mlp(
|
108 |
+
in_features=dim_in,
|
109 |
+
hidden_features=dim_in * 2,
|
110 |
+
out_features=dim_in + self.target_dim,
|
111 |
+
drop=0,
|
112 |
+
)
|
113 |
+
|
114 |
+
self.ffeat_updater = nn.Sequential(
|
115 |
+
nn.Linear(dim_in, dim_in), nn.GELU()
|
116 |
+
)
|
117 |
+
|
118 |
+
# sine and cosine embed for camera parameters
        self.embed_pose = PoseEmbedding(
            target_dim=self.target_dim,
            n_harmonic_functions=(dim_in // self.target_dim) // 2,
            append_input=False,
        )
        self.embed_pose_proj = nn.Linear(self.embed_pose.out_dim, dim_in)

    def forward(self, aggregated_tokens_list, batch, patch_start_idx, iters=4):
        """
        """
        tokens = aggregated_tokens_list[-1]
        # only use the Pose token for camera prediction
        pose_tokens = tokens[:, :, 0]
        pose_tokens = self.token_norm(pose_tokens)
        pose_tokens = self.proj(pose_tokens)

        B, S, C = pose_tokens.shape

        if self.trunk_depth < 0:
            pred_pose_enc = self.pose_branch(pose_tokens)
            pred_pose_enc_list = [activate_pose(pred_pose_enc, **self.act_dict)]
        elif self.new_trunk:
            pred_pose_enc_list = self.new_trunk_fn(pose_tokens, iters)
        else:
            pred_pose_enc_list = self.old_trunk_fn(pose_tokens, iters)

        # TODO add act here
        return pred_pose_enc_list

    def new_trunk_fn(self, pose_tokens, iters):
        B, S, C = pose_tokens.shape

        pred_pose_enc = None
        pose_tokens_init = pose_tokens.clone()

        pred_pose_enc_list = []

        for iter_num in range(iters):
            if pred_pose_enc is None:
                # model_input = self.empty_representation  BxSxC
                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
            else:
                pred_pose_enc = pred_pose_enc.detach()
                module_input = self.embed_pose(pred_pose_enc)

            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
            pose_tokens_modulated = pose_tokens_modulated + pose_tokens

            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))

            if pred_pose_enc is None:
                pred_pose_enc = pred_pose_enc_delta
            else:
                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta

            if self.update_new_trunk_tokens:
                pose_tokens = pose_tokens_modulated + pose_tokens_init

            pred_pose_enc_list.append(activate_pose(pred_pose_enc, **self.act_dict))

        return pred_pose_enc_list

    def old_trunk_fn(self, pose_tokens, iters):
        B, S, C = pose_tokens.shape

        pred_pose_enc = torch.zeros(B, S, self.target_dim).to(pose_tokens.device)

        pose_tokens_init = pose_tokens.clone()

        pred_pose_enc_list = []

        for iter_num in range(iters):
            pred_pose_enc = pred_pose_enc.detach()

            # Embed the camera parameters and add to pose_tokens
            pose_embed = self.embed_pose_proj(self.embed_pose(pred_pose_enc))
            pose_tokens = pose_tokens + pose_embed

            # Run trunk transformers on pose_tokens
            pose_tokens = self.trunk(pose_tokens)

            # Predict the delta feat and pose encoding at each iteration
            delta = self.pose_branch(self.trunk_norm(pose_tokens))
            delta_pred_pose_enc = delta[..., : self.target_dim]
            delta_feat = delta[..., self.target_dim :]

            pose_tokens = self.ffeat_updater(self.ffeat_norm(delta_feat)) + pose_tokens

            pred_pose_enc = pred_pose_enc + delta_pred_pose_enc
            pose_tokens = (pose_tokens + pose_tokens_init) / 2
            pred_pose_enc_list.append(activate_pose(pred_pose_enc, **self.act_dict))

        return pred_pose_enc_list
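Note: both trunk functions above follow the same iterative refinement pattern, i.e. detach the current pose encoding, re-embed it, push the conditioned tokens through the trunk, and accumulate a predicted delta. Below is a minimal, self-contained sketch of that loop for reference; the tiny linear modules (embed, trunk, branch) are stand-ins for embed_pose/embed_pose_proj, the transformer trunk, and pose_branch, not the real CameraHead components.

# Toy illustration of the detach-and-re-embed refinement loop (stand-in modules only).
import torch
import torch.nn as nn

dim, target_dim, iters = 64, 8, 4
embed = nn.Linear(target_dim, dim)    # plays the role of embed_pose + embed_pose_proj
trunk = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
branch = nn.Linear(dim, target_dim)   # plays the role of pose_branch (delta only)

tokens = torch.randn(2, 5, dim)       # B x S x C pose tokens
pred = torch.zeros(2, 5, target_dim)  # start from an all-zero pose encoding
preds = []
for _ in range(iters):
    pred = pred.detach()              # gradients flow only through the current iteration
    x = trunk(tokens + embed(pred))   # condition tokens on the current estimate
    pred = pred + branch(x)           # accumulate the predicted delta
    preds.append(pred)
# preds mirrors pred_pose_enc_list: one progressively refined estimate per iteration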
vggt/heads/dpt_head.py
ADDED
@@ -0,0 +1,521 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# linear head implementation for DUST3R
# --------------------------------------------------------

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from .head_act import activate_head
from .utils import normalized_view_plane_uv, HarmonicEmbedding, position_grid_to_embed


class DPTHead(nn.Module):
    """
    """
    def __init__(self,
                 dim_in,
                 patch_size=14,
                 output_dim=4,
                 normalize_act="inv_log",
                 normalize_act_conf="expp1",
                 features=256,
                 use_bn=False,
                 use_clstoken=False,
                 out_channels=[256, 512, 1024, 1024],
                 intermediate_layer_idx=[4, 11, 17, 23],
                 shared_norm=True,
                 add_rgb=False,
                 head_use_checkpoint=False,
                 groups=1,
                 shallow_conv=False,
                 load_da_str=None,
                 dpt_layer_norm=False,
                 pos_embed=False,
                 feature_only=False,
                 down_ratio=1,
                 **kwargs,
                 ):
        super(DPTHead, self).__init__()

        in_channels = dim_in
        self.add_rgb = add_rgb
        self.patch_size = patch_size
        self.intermediate_layer_idx = intermediate_layer_idx
        self.shared_norm = shared_norm
        self.normalize_act = normalize_act
        self.normalize_act_conf = normalize_act_conf
        self.head_use_checkpoint = head_use_checkpoint
        self.pos_embed = pos_embed
        self.feature_only = feature_only
        self.down_ratio = down_ratio

        # if self.pos_embed:
        #     self.pose_embed_fn_64 = HarmonicEmbedding(n_harmonic_functions=64, omega_0=1.0, logspace=True, append_input=False)
        #     self.pose_embed_fn_128 = HarmonicEmbedding(n_harmonic_functions=128, omega_0=1.0, logspace=True, append_input=False)
        #     self.pose_embed_fn_256 = HarmonicEmbedding(n_harmonic_functions=256, omega_0=1.0, logspace=True, append_input=False)
        #     self.pose_embed_fn_512 = HarmonicEmbedding(n_harmonic_functions=512, omega_0=1.0, logspace=True, append_input=False)
        #     self.pose_embed_fn_1024 = HarmonicEmbedding(n_harmonic_functions=1024, omega_0=1.0, logspace=True, append_input=False)

        if self.shared_norm:
            self.norm = nn.LayerNorm(in_channels)
        else:
            self.norm = nn.ModuleList([nn.LayerNorm(in_channels) for _ in range(len(self.intermediate_layer_idx))])

        self.use_clstoken = use_clstoken

        self.projects = nn.ModuleList([
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channel,
                kernel_size=1,
                stride=1,
                padding=0,
            ) for out_channel in out_channels
        ])

        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(
                in_channels=out_channels[0],
                out_channels=out_channels[0],
                kernel_size=4,
                stride=4,
                padding=0),
            nn.ConvTranspose2d(
                in_channels=out_channels[1],
                out_channels=out_channels[1],
                kernel_size=2,
                stride=2,
                padding=0),
            nn.Identity(),
            nn.Conv2d(
                in_channels=out_channels[3],
                out_channels=out_channels[3],
                kernel_size=3,
                stride=2,
                padding=1)
        ])

        if use_clstoken:
            raise ValueError("CLS token is not supported for DPT head Now")
            self.readout_projects = nn.ModuleList()
            for _ in range(len(self.projects)):
                self.readout_projects.append(
                    nn.Sequential(
                        nn.Linear(2 * in_channels, in_channels),
                        nn.GELU()))

        self.scratch = _make_scratch(
            out_channels,
            features,
            groups=1,
            expand=False,
        )

        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn, has_residual=False, groups=groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)

        head_features_1 = features
        head_features_2 = 32

        if not self.feature_only:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
            conv2_in_channels = head_features_1 // 2 + 3 * int(self.add_rgb)

            if dpt_layer_norm:
                self.scratch.output_conv2 = nn.Sequential(
                    ChannelLayerNorm(conv2_in_channels),
                    nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(True),
                    ChannelLayerNorm(head_features_2),
                    nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
                )
            else:
                self.scratch.output_conv2 = nn.Sequential(
                    nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(True),
                    nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
                )
        else:
            self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)

        if load_da_str is not None:
            from off3d.utils.train_utils import remove_if_not_match

            da_path = os.path.join(torch.hub.get_dir(), load_da_str)
            da_model = torch.load(da_path)
            to_load_dict = {}
            for k in da_model.keys():
                if "depth_head" in k:
                    to_load_dict[k.replace("depth_head.", "")] = da_model[k]
            all_keys = list(to_load_dict.keys())
            model_state_dict = self.state_dict()
            for cur_key in all_keys:
                to_load_dict = remove_if_not_match(model_state_dict, to_load_dict, cur_key)

            missing, unexpected = self.load_state_dict(to_load_dict, strict=False)

            print("Missing keys in DPT head: ", missing)
            print("Unexpected keys in DPT head: ", unexpected)
            for layer in self.scratch.output_conv2:
                if isinstance(layer, (nn.Conv2d, nn.Linear)):
                    layer.weight.data *= 0.1
                    layer.bias.data *= 0.1

    def forward(self, aggregated_tokens_list, batch, patch_start_idx):
        B, _, _, H, W = batch["images"].shape
        S = aggregated_tokens_list[0].shape[1]

        patch_h, patch_w = H // self.patch_size, W // self.patch_size

        # TODO use rgb as input for the DPT head

        out = []
        dpt_idx = 0

        for layer_idx in self.intermediate_layer_idx:
            if self.use_clstoken:
                raise NotImplementedError("CLS token is not supported for DPT head Now")
            x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
            x = x.view(B * S, -1, x.shape[-1])

            if self.shared_norm:
                x = self.norm(x)
            else:
                x = self.norm[dpt_idx](x)

            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))

            if self.head_use_checkpoint:
                # e.g., from Bx2048xpatch_h*patch_w to Bx256xpatch_h*patch_w
                x = torch.utils.checkpoint.checkpoint(self.projects[dpt_idx], x, use_reentrant=False)
                if self.pos_embed:
                    x = self._apply_pos_embed(x, W, H)
                x = torch.utils.checkpoint.checkpoint(self.resize_layers[dpt_idx], x, use_reentrant=False)
            else:
                x = self.projects[dpt_idx](x)
                if self.pos_embed:
                    x = self._apply_pos_embed(x, W, H)
                x = self.resize_layers[dpt_idx](x)

            out.append(x)
            dpt_idx += 1

        if self.head_use_checkpoint:
            out = torch.utils.checkpoint.checkpoint(self.scratch_forward, out, use_reentrant=False)
        else:
            out = self.scratch_forward(out)

        # out = F.interpolate(out, (int(patch_h * self.patch_size), int(patch_w * self.patch_size)), mode="bilinear", align_corners=True)
        out = custom_interpolate(out, (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)), mode="bilinear", align_corners=True)

        if self.pos_embed:
            out = self._apply_pos_embed(out, W, H)

        if self.feature_only:
            return out

        if self.add_rgb:
            # NOTE batch["images"] is in the range of [0, 1]
            out = torch.cat([out, batch["images"].view(B * S, 3, H, W).clip(0, 1)], dim=1)

        if self.head_use_checkpoint:
            out = torch.utils.checkpoint.checkpoint(self.scratch.output_conv2, out, use_reentrant=False)
        else:
            out = self.scratch.output_conv2(out)

        preds, conf = activate_head(out, normalize_act=self.normalize_act, normalize_act_conf=self.normalize_act_conf)

        # back to B, S
        preds = preds.view(B, S, *preds.shape[1:])  # B, S, H, W, 3
        conf = conf.view(B, S, *conf.shape[1:])     # B, S, H, W

        return preds, conf

    def _apply_pos_embed(self, x, W, H, ratio=0.1):
        """Apply positional embedding to the input tensor."""
        patch_w = x.shape[-1]
        patch_h = x.shape[-2]

        pos_embed = normalized_view_plane_uv(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
        pos_embed = pos_embed * ratio
        pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pos_embed

    def scratch_forward(self, out):
        layer_1, layer_2, layer_3, layer_4 = out

        layer_1_rn = self.scratch.layer1_rn(layer_1)  # layer_1: [32, 256, 148, 148]
        layer_2_rn = self.scratch.layer2_rn(layer_2)  # layer_2: [32, 512, 74, 74]
        layer_3_rn = self.scratch.layer3_rn(layer_3)  # layer_3: [32, 1024, 37, 37]
        layer_4_rn = self.scratch.layer4_rn(layer_4)  # layer_4: [32, 1024, 19, 19]

        out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        del layer_4_rn, layer_4

        out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
        del layer_3_rn, layer_3

        out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
        del layer_2_rn, layer_2

        out = self.scratch.refinenet1(out, layer_1_rn)
        del layer_1_rn, layer_1

        out = self.scratch.output_conv1(out)
        return out


################################################################################
# Modules
################################################################################


def _make_fusion_block(features, use_bn, size=None, has_residual=True, groups=1, shallow_conv=False, dpt_layer_norm=False):
    return FeatureFusionBlock(
        features,
        nn.ReLU(True),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
        shallow_conv=shallow_conv,
        dpt_layer_norm=dpt_layer_norm,
    )


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape * 2
        out_shape3 = out_shape * 4
        if len(in_shape) >= 4:
            out_shape4 = out_shape * 8

    scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)

    return scratch


class ResidualConvUnit(nn.Module):
    """Residual convolution module."""

    def __init__(self, features, activation, bn, groups=1, shallow_conv=False, dpt_layer_norm=False):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn
        self.groups = groups

        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        self.shallow_conv = shallow_conv
        if not self.shallow_conv:
            self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)

        # if self.bn == True:
        #     self.bn1 = nn.BatchNorm2d(features)
        #     self.bn2 = nn.BatchNorm2d(features)

        if dpt_layer_norm:
            self.norm1 = ChannelLayerNorm(features)
            self.norm2 = ChannelLayerNorm(features)
        else:
            self.norm1 = None
            self.norm2 = None

        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.activation(x)
        out = self.conv1(out)
        if self.norm1 is not None:
            out = self.norm1(out)

        if not self.shallow_conv:
            out = self.activation(out)
            out = self.conv2(out)
            if self.norm2 is not None:
                out = self.norm2(out)

        # if self.groups > 1:
        #     out = self.conv_merge(out)

        return self.skip_add.add(out, x)


class FeatureFusionBlock(nn.Module):
    """Feature fusion block."""

    def __init__(
        self,
        features,
        activation,
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=None,
        has_residual=True,
        groups=1,
        shallow_conv=False,
        dpt_layer_norm=False,
    ):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners
        self.groups = groups

        self.expand = expand
        out_features = features
        if self.expand == True:
            out_features = features // 2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups)

        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)

        self.has_residual = has_residual
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups, shallow_conv=shallow_conv, dpt_layer_norm=dpt_layer_norm)

        self.skip_add = nn.quantized.FloatFunctional()
        self.size = size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if self.has_residual:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        # output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
        output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)

        output = self.out_conv(output)

        return output


def custom_interpolate(x, size=None, scale_factor=None, mode="bilinear", align_corners=True):
    if size is None:
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
    INT_MAX = 1610612736

    input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]

    if input_elements > INT_MAX:
        # Split x into chunks along the batch dimension
        chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
        interpolated_chunks = [nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks]
        x = torch.cat(interpolated_chunks, dim=0)
        return x.contiguous()
    else:
        return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)


class ChannelLayerNorm(nn.Module):
    def __init__(self, num_channels):
        super().__init__()
        self.ln = nn.LayerNorm(num_channels)

    def forward(self, x):
        # x: [N, C, H, W]
        x = x.permute(0, 2, 3, 1)  # -> [N, H, W, C]
        x = self.ln(x)             # now LN sees 'C' as the last dimension
        x = x.permute(0, 3, 1, 2)  # -> [N, C, H, W]
        return x
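For reference, here is a toy sketch of the token-to-pyramid step in DPTHead.forward: patch tokens are reshaped into a 2D grid, projected with a 1x1 convolution, then resized to 4x / 2x / 1x / 0.5x resolution before fusion. The grid size and channel widths below are made up for brevity; only the shape behaviour is illustrative.

# Shape-level illustration of the projection + multi-scale resize used in DPTHead.
import torch
import torch.nn as nn

B_S, patch_h, patch_w, C = 2, 24, 32, 128
tokens = torch.randn(B_S, patch_h * patch_w, C)            # (B*S, N_patches, C)
x = tokens.permute(0, 2, 1).reshape(B_S, C, patch_h, patch_w)

proj = nn.Conv2d(C, 64, kernel_size=1)
resize_layers = nn.ModuleList([
    nn.ConvTranspose2d(64, 64, kernel_size=4, stride=4),   # 4x upsample (finest level)
    nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),   # 2x upsample
    nn.Identity(),                                         # keep resolution
    nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)  # 2x downsample (coarsest level)
])
for layer in resize_layers:
    print(layer(proj(x)).shape)
# -> (2, 64, 96, 128), (2, 64, 48, 64), (2, 64, 24, 32), (2, 64, 12, 16)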
vggt/heads/head_act.py
ADDED
@@ -0,0 +1,97 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# post process function for all heads: extract 3D points/confidence from output
# --------------------------------------------------------
import torch
import torch.nn.functional as F


def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
    T = pred_pose_enc[..., :3]
    quat = pred_pose_enc[..., 3:7]
    fl = pred_pose_enc[..., 7:]  # or fov

    T = base_pose_act(T, trans_act)
    quat = base_pose_act(quat, quat_act)
    fl = base_pose_act(fl, fl_act)  # or fov

    pred_pose_enc = torch.cat([T, quat, fl], dim=-1)

    return pred_pose_enc


def base_pose_act(pose_enc, act_type="linear"):
    if act_type == "linear":
        return pose_enc
    elif act_type == "inv_log":
        return inverse_log_transform(pose_enc)
    elif act_type == "exp":
        return torch.exp(pose_enc)
    elif act_type == "relu":
        return F.relu(pose_enc)
    else:
        raise ValueError(f"Unknown act_type: {act_type}")


def activate_head(out, normalize_act="norm_exp", normalize_act_conf="expp1"):
    """
    """
    # Move channels from last dim to the 4th dimension => (B, H, W, C)
    fmap = out.permute(0, 2, 3, 1)  # B, H, W, C expected

    # Split into xyz (first C-1 channels) and confidence (last channel)
    xyz = fmap[:, :, :, :-1]
    conf = fmap[:, :, :, -1]

    if normalize_act == "norm_exp":
        # 1) distance d = ||xyz||
        # 2) normalize xyz => xyz / d
        # 3) multiply by torch.expm1(d)
        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        xyz_normed = xyz / d
        pts3d = xyz_normed * torch.expm1(d)
    elif normalize_act == "norm":
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif normalize_act == "exp":
        pts3d = torch.exp(xyz)
    elif normalize_act == "relu":
        pts3d = F.relu(xyz)
    elif normalize_act == "inv_log":
        pts3d = inverse_log_transform(xyz)
    elif normalize_act == "xy_inv_log":
        xy, z = xyz.split([2, 1], dim=-1)
        z = inverse_log_transform(z)
        pts3d = torch.cat([xy * z, z], dim=-1)
    elif normalize_act == "sigmoid":
        pts3d = torch.sigmoid(xyz)
    elif normalize_act == "linear":
        pts3d = xyz
    else:
        raise ValueError(f"Unknown normalize_act: {normalize_act}")

    # reg_dense_conf for mode='exp', with vmin=1, vmax=inf
    # => conf_out = 1 + e^(conf)
    # (since clip(max=vmax - vmin) with vmax=inf basically doesn't limit anything)
    if normalize_act_conf == "expp1":
        conf_out = 1 + conf.exp()
    elif normalize_act_conf == "expp0":
        conf_out = conf.exp()
    elif normalize_act_conf == "sigmoid":
        conf_out = torch.sigmoid(conf)
    else:
        raise ValueError(f"Unknown normalize_act_conf: {normalize_act_conf}")

    # Final outputs: 3D points and confidence
    return pts3d, conf_out


def inverse_log_transform(y):
    return torch.sign(y) * (torch.expm1(torch.abs(y)))
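A quick interface check for the activations above, assuming the repository root is on the Python path so that vggt.heads.head_act is importable (the same relative path dpt_head.py uses); otherwise the two functions can simply be pasted into a script.

# Sanity check of activate_head / activate_pose on dummy tensors.
import torch
from vggt.heads.head_act import activate_head, activate_pose

out = torch.randn(2, 4, 28, 28)                  # (B*S, 3+1, H, W): xyz channels + confidence
pts3d, conf = activate_head(out, normalize_act="norm_exp", normalize_act_conf="expp1")
print(pts3d.shape, conf.shape)                   # (2, 28, 28, 3) and (2, 28, 28)
print(bool(conf.min() > 1))                      # expp1 keeps confidence strictly above 1

pose_enc = torch.randn(2, 5, 9)                  # 3 translation + 4 quaternion + 2 focal/fov
print(activate_pose(pose_enc, fl_act="relu").shape)   # (2, 5, 9)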
vggt/heads/track_head.py
ADDED
@@ -0,0 +1,267 @@
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# linear head implementation for DUST3R
# --------------------------------------------------------

import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from .head_act import activate_head
from .utils import normalized_view_plane_uv, HarmonicEmbedding, position_grid_to_embed
from .dpt_head import DPTHead
from .match_head import MatchHead
from ..track_modules.base_track_predictor import BaseTrackerPredictor
from ..track_modules.base_track_predictor_v2 import BaseTrackerPredictorV2

EPS = 1e-6


def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    # x and mask are the same shape, or at least broadcastably so
    # returns shape-1
    # axis can be a list of axes
    for a, b in zip(x.size(), mask.size()):
        assert a == b  # some shape mismatch!
    prod = x * mask
    if dim is None:
        numer = torch.sum(prod)
        denom = EPS + torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)

    mean = numer / denom
    return mean


def balanced_ce_loss(pred, gt, valid=None):
    """Balanced cross entropy loss.
    pred: predicted scores
    gt: binary ground truth
    valid: validity mask
    """
    # pred and gt are the same shape
    for a, b in zip(pred.size(), gt.size()):
        assert a == b  # some shape mismatch!
    if valid is not None:
        for a, b in zip(pred.size(), valid.size()):
            assert a == b  # some shape mismatch!
    else:
        valid = torch.ones_like(gt)

    pos = (gt > 0.95).float()
    neg = (gt < 0.05).float()

    label = pos * 2.0 - 1.0
    a = -label * pred
    b = F.relu(a)
    loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))

    pos_loss = reduce_masked_mean(loss, pos * valid)
    neg_loss = reduce_masked_mean(loss, neg * valid)

    balanced_loss = pos_loss + neg_loss

    return balanced_loss, loss


def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8, vis_aware=False, huber=False, delta=10, vis_aware_w=0.1, **kwargs):
    """Loss function defined over sequence of flow predictions"""
    B, S, N, D = flow_gt.shape
    assert D == 2
    B, S1, N = vis.shape
    B, S2, N = valids.shape
    assert S == S1
    assert S == S2
    n_predictions = len(flow_preds)
    flow_loss = 0.0

    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        flow_pred = flow_preds[i]

        i_loss = (flow_pred - flow_gt).abs()   # B, S, N, 2
        i_loss = torch.mean(i_loss, dim=3)     # B, S, N

        # Combine valids and vis for per-frame valid masking.
        combined_mask = torch.logical_and(valids, vis)
        # valids * vis.float()  # B, S, N

        # vis_aware weighting. Apply BEFORE reduce_masked_mean
        if vis_aware:
            combined_mask = combined_mask.float() * (1.0 + vis_aware_w)  # Add, don't add to the mask itself.
            # combined_mask = torch.clamp(combined_mask, 0.0, 1.0)  # No need to clamp.
            # Apply the mask *before* taking the mean.
            # i_loss = i_loss * combined_mask
            # flow_loss += i_weight * i_loss.mean()
            flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask)
        else:
            if combined_mask.numel() > 10:
                # flow_loss += i_weight * i_loss.mean()
                i_loss = i_loss[combined_mask]
                flow_loss += i_weight * i_loss.mean()
            else:
                flow_loss += 0

        # # Handle the case where no points are valid.
        # if combined_mask.sum() > 0:
        #     flow_loss += i_weight * reduce_masked_mean(i_loss, combined_mask)  # Pass combined_mask
        # else: No valid points, so this term contributes 0 to the loss.
        #     flow_loss += 0. (This is implicit)

    # Avoid division by zero if n_predictions is 0 (though it shouldn't be).
    if n_predictions > 0:
        flow_loss = flow_loss / n_predictions

    return flow_loss


class TrackHead(nn.Module):
    """
    Track head that uses DPT/Match head to process tokens and BaseTrackerPredictor for tracking.
    """
    def __init__(self,
                 dim_in,
                 patch_size=16,
                 features=128,
                 feature_extractor_type="dpt",  # or "match"
                 train_query_points=128,
                 feature_extractor_kwargs={},
                 tracker_kwargs={},
                 loss_kwargs={},
                 iters=4,
                 use_base_tracker_v2=False,
                 predict_conf=False,
                 random_query_points=None,
                 **kwargs):
        super().__init__()

        self.patch_size = patch_size
        self.feature_extractor_type = feature_extractor_type
        self.train_query_points = train_query_points
        self.random_query_points = random_query_points

        # Initialize feature extractor (DPT or Match head)
        if feature_extractor_type == "dpt":
            self.feature_extractor = DPTHead(
                dim_in=dim_in,
                patch_size=patch_size,
                features=features,
                feature_only=True,  # Only output features, no activation
                **feature_extractor_kwargs
            )
        elif feature_extractor_type == "match":
            raise NotImplementedError("Match head is not implemented for track head")
            self.feature_extractor = MatchHead(
                dim_in=dim_in,
                patch_size=patch_size,
                features=features,
                **feature_extractor_kwargs
            )
        else:
            raise ValueError(f"Unknown feature_extractor_type: {feature_extractor_type}")

        # Initialize tracker
        if use_base_tracker_v2:
            self.tracker = BaseTrackerPredictorV2(
                latent_dim=features,  # Match the output_dim of feature extractor
                predict_conf=predict_conf,
                **tracker_kwargs
            )
        else:
            self.tracker = BaseTrackerPredictor(
                latent_dim=features,  # Match the output_dim of feature extractor
                predict_conf=predict_conf,
                **tracker_kwargs
            )

        self.loss_kwargs = loss_kwargs
        self.iters = iters

    def _compute_losses(self, coord_preds, vis_scores, conf_scores, batch):
        """Compute tracking losses using sequence_loss"""
        gt_tracks = batch["tracks"]                # B, S, N, 2
        gt_track_vis_mask = batch["track_vis_mask"]  # B, S, N

        # if self.training and hasattr(self, "train_query_points"):
        train_query_points = coord_preds[-1].shape[2]
        gt_tracks = gt_tracks[:, :, :train_query_points]
        gt_track_vis_mask = gt_track_vis_mask[:, :, :train_query_points]

        # Create validity mask that filters out tracks not visible in first frame
        valids = torch.ones_like(gt_track_vis_mask)
        mask = gt_track_vis_mask[:, 0, :] == True
        valids = valids * mask.unsqueeze(1)

        # Compute tracking loss using sequence_loss
        track_loss = sequence_loss(
            flow_preds=coord_preds,
            flow_gt=gt_tracks,
            vis=gt_track_vis_mask,
            valids=valids,
            **self.loss_kwargs
        )

        vis_loss = F.binary_cross_entropy_with_logits(vis_scores[valids], gt_track_vis_mask[valids].float())
        # within 3 pixels
        if conf_scores is not None:
            gt_conf_mask = (gt_tracks - coord_preds[-1]).norm(dim=-1) < 3
            conf_loss = F.binary_cross_entropy_with_logits(conf_scores[valids], gt_conf_mask[valids].float())
        else:
            conf_loss = 0

        return track_loss, vis_loss, conf_loss

    def forward(self, aggregated_tokens_list, batch, patch_start_idx):
        B, S, _, H, W = batch["images"].shape

        gt_tracks = batch["tracks"]  # B, S, N, 2
        # gt_track_vis_mask = batch["track_vis_mask"]  # B, S, N

        # Extract features using DPT/Match head
        if self.feature_extractor_type == "dpt":
            feature_maps = self.feature_extractor(aggregated_tokens_list, batch, patch_start_idx)
        else:  # match head
            feature_maps = self.feature_extractor(aggregated_tokens_list, batch, patch_start_idx)["descriptor"]

        feature_maps = feature_maps.view(B, S, *feature_maps.shape[1:]).clone()

        # Get query points from batch
        query_points = gt_tracks[:, 0]  # Use first frame's points as query

        if self.training:
            if self.random_query_points is not None:
                min_val = self.random_query_points[0]
                max_val = self.random_query_points[1]
                mu = max_val  # Mean centered at the upper bound
                sigma = (max_val - min_val) / 2.71  # Standard deviation, exp
                train_query_points = int(random.gauss(mu, sigma))
                train_query_points = max(min(train_query_points, max_val), min_val)  # Clamp to ensure value is within range
            else:
                train_query_points = self.train_query_points
            query_points = query_points[:, :train_query_points]

        # Predict tracks using BaseTrackerPredictor
        # coord_preds: a list of B, S, N, 2
        # vis_scores: B, S, N
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=self.iters,
        )

        # Calculate losses if in training mode
        track_loss, vis_loss, conf_loss = self._compute_losses(coord_preds, vis_scores, conf_scores, batch)

        loss_dict = {
            "loss_track": track_loss,
            "loss_vis": vis_loss,
            "loss_track_conf": conf_loss,
            "last_track_pred": coord_preds[-1],
        }
        return loss_dict
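To make the loss interface concrete, here is a shape-level sanity check for sequence_loss with random data, two refinement iterations, and everything marked visible and valid. It assumes sequence_loss (defined above) is in scope; note that importing the whole module would also pull in match_head and the track_modules package, which are not part of this commit, so copying the function out is the easiest way to try it.

# Dummy-data check of the sequence_loss call signature and gamma weighting.
import torch

B, S, N = 2, 4, 16
flow_gt = torch.rand(B, S, N, 2) * 100                                     # ground-truth 2D tracks
flow_preds = [flow_gt + torch.randn(B, S, N, 2) * s for s in (8.0, 2.0)]   # coarse, then finer prediction
vis = torch.ones(B, S, N, dtype=torch.bool)
valids = torch.ones(B, S, N, dtype=torch.bool)

loss = sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8)
print(loss)   # the last (most refined) prediction carries the largest weight, gamma**0 = 1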
vggt/heads/utils.py
ADDED
@@ -0,0 +1,309 @@
import torch
import torch.nn as nn

from typing import Optional


def make_sincos_pos_embed(
    embed_dim: int, pos: torch.Tensor, omega_0: float = 100
) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Args:
        - embed_dim: The embedding dimension.
        - pos: The position to generate the embedding from.

    Returns:
        - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / omega_0**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb.float()


def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
    """
    Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
        embed_dim: Output channel dimension for embeddings

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    pos_flat = pos_grid.reshape(-1, grid_dim)  # Flatten to (H*W, 2)

    # Process x and y coordinates separately
    emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)  # (H*W, D/2)
    emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)  # (H*W, D/2)

    # Combine and reshape
    emb = torch.cat([emb_x, emb_y], dim=-1)  # (H*W, D)

    return emb.view(H, W, embed_dim)  # (H, W, D)


class HarmonicEmbedding(torch.nn.Module):
    def __init__(
        self,
        n_harmonic_functions: int = 6,
        omega_0: float = 1.0,
        logspace: bool = True,
        append_input: bool = True,
    ) -> None:
        """
        The harmonic embedding layer supports the classical
        Nerf positional encoding described in
        `NeRF <https://arxiv.org/abs/2003.08934>`_
        and the integrated position encoding in
        `MIP-NeRF <https://arxiv.org/abs/2103.13415>`_.

        During the inference you can provide the extra argument `diag_cov`.

        If `diag_cov is None`, it converts
        rays parametrized with a `ray_bundle` to 3D points by
        extending each ray according to the corresponding length.
        Then it converts each feature
        (i.e. vector along the last dimension) in `x`
        into a series of harmonic features `embedding`,
        where for each i in range(dim) the following are present
        in embedding[...]::

            [
                sin(f_1*x[..., i]),
                sin(f_2*x[..., i]),
                ...
                sin(f_N * x[..., i]),
                cos(f_1*x[..., i]),
                cos(f_2*x[..., i]),
                ...
                cos(f_N * x[..., i]),
                x[..., i],              # only present if append_input is True.
            ]

        where N corresponds to `n_harmonic_functions-1`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `diag_cov is not None`, it approximates
        conical frustums following a ray bundle as gaussians,
        defined by x, the means of the gaussians and diag_cov,
        the diagonal covariances.
        Then it converts each gaussian
        into a series of harmonic features `embedding`,
        where for each i in range(dim) the following are present
        in embedding[...]::

            [
                sin(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
                sin(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),
                ...
                sin(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
                cos(f_1*x[..., i]) * exp(0.5 * f_1**2 * diag_cov[..., i,]),
                cos(f_2*x[..., i]) * exp(0.5 * f_2**2 * diag_cov[..., i,]),
                ...
                cos(f_N * x[..., i]) * exp(0.5 * f_N**2 * diag_cov[..., i,]),
                x[..., i],              # only present if append_input is True.
            ]

        where N equals `n_harmonic_functions-1`, and f_i is a scalar
        denoting the i-th frequency of the harmonic embedding.

        If `logspace==True`, the frequencies `[f_1, ..., f_N]` are
        powers of 2:
            `f_1, ..., f_N = 2**torch.arange(n_harmonic_functions)`

        If `logspace==False`, frequencies are linearly spaced between
        `1.0` and `2**(n_harmonic_functions-1)`:
            `f_1, ..., f_N = torch.linspace(
                1.0, 2**(n_harmonic_functions-1), n_harmonic_functions
            )`

        Note that `x` is also premultiplied by the base frequency `omega_0`
        before evaluating the harmonic functions.

        Args:
            n_harmonic_functions: int, number of harmonic features
            omega_0: float, base frequency
            logspace: bool, whether to space the frequencies in
                logspace or linear space
            append_input: bool, whether to concat the original
                input to the harmonic embedding. If true the
                output is of the form (embed.sin(), embed.cos(), x)
        """
        super().__init__()

        if logspace:
            frequencies = 2.0 ** torch.arange(
                n_harmonic_functions, dtype=torch.float32
            )
        else:
            frequencies = torch.linspace(
                1.0,
                2.0 ** (n_harmonic_functions - 1),
                n_harmonic_functions,
                dtype=torch.float32,
            )

        self.register_buffer(
            "_frequencies", frequencies * omega_0, persistent=False
        )
        self.register_buffer(
            "_zero_half_pi",
            torch.tensor([0.0, 0.5 * torch.pi]),
            persistent=False,
        )
        self.append_input = append_input

    def forward(
        self, x: torch.Tensor, diag_cov: Optional[torch.Tensor] = None, **kwargs
    ) -> torch.Tensor:
        """
        Args:
            x: tensor of shape [..., dim]
            diag_cov: An optional tensor of shape `(..., dim)`
                representing the diagonal covariance matrices of our Gaussians, joined with x
                as means of the Gaussians.

        Returns:
            embedding: a harmonic embedding of `x` of shape
            [..., (n_harmonic_functions * 2 + int(append_input)) * num_points_per_ray]
        """
        # [..., dim, n_harmonic_functions]
        embed = x[..., None] * self._frequencies
        # [..., 1, dim, n_harmonic_functions] + [2, 1, 1] => [..., 2, dim, n_harmonic_functions]
        embed = embed[..., None, :, :] + self._zero_half_pi[..., None, None]
        # Use the trig identity cos(x) = sin(x + pi/2)
        # and do one vectorized call to sin([x, x+pi/2]) instead of (sin(x), cos(x)).
        embed = embed.sin()
        if diag_cov is not None:
            x_var = diag_cov[..., None] * torch.pow(self._frequencies, 2)
            exp_var = torch.exp(-0.5 * x_var)
            # [..., 2, dim, n_harmonic_functions]
            embed = embed * exp_var[..., None, :, :]

        embed = embed.reshape(*x.shape[:-1], -1)

        if self.append_input:
            return torch.cat([embed, x], dim=-1)
        return embed

    @staticmethod
    def get_output_dim_static(
        input_dims: int, n_harmonic_functions: int, append_input: bool
    ) -> int:
        """
        Utility to help predict the shape of the output of `forward`.

        Args:
            input_dims: length of the last dimension of the input tensor
            n_harmonic_functions: number of embedding frequencies
            append_input: whether or not to concat the original
                input to the harmonic embedding
        Returns:
            int: the length of the last dimension of the output tensor
        """
        return input_dims * (2 * n_harmonic_functions + int(append_input))

    def get_output_dim(self, input_dims: int = 3) -> int:
        """
        Same as above. The default for input_dims is 3 for 3D applications
        which use harmonic embedding for positional encoding,
        so the input might be xyz.
        """
        return self.get_output_dim_static(
            input_dims, len(self._frequencies), self.append_input
        )


class PoseEmbedding(nn.Module):
    def __init__(self, target_dim, n_harmonic_functions=10, append_input=True):
        super().__init__()

        self._emb_pose = HarmonicEmbedding(
            n_harmonic_functions=n_harmonic_functions, append_input=append_input
        )

        self.out_dim = self._emb_pose.get_output_dim(target_dim)

    def forward(self, pose_encoding):
        e_pose_encoding = self._emb_pose(pose_encoding)
        return e_pose_encoding


def random_mask_single_patch_vectorized(images, patch_size=(16, 16)):
    """
    Randomly masks a single patch in a batch of images using fully vectorized operations.
    :param images: Tensor of shape [B, 3, H, W]
    :param patch_size: Tuple (ph, pw), size of the patch to mask
    """
    B, C, H, W = images.shape
    ph, pw = patch_size

    # Generate random positions for the top-left corner of the patch
    x_positions = torch.randint(0, W - pw, (B, 1, 1))
    y_positions = torch.randint(0, H - ph, (B, 1, 1))

    # Compute patch grid indices
    patch_x = torch.arange(pw).reshape(1, 1, pw)
    patch_y = torch.arange(ph).reshape(1, ph, 1)

    # Broadcast patch indices to each position
    x_indices = x_positions + patch_x
    y_indices = y_positions + patch_y

    # Expand the indices to cover all channels and all images in the batch
    x_indices = x_indices.expand(B, ph, pw)
    y_indices = y_indices.expand(B, ph, pw)

    # Flatten the indices to apply the mask using advanced indexing
    batch_indices = torch.arange(B).unsqueeze(-1).expand(B, ph * pw)
    x_indices = x_indices.reshape(B, ph * pw)
    y_indices = y_indices.reshape(B, ph * pw)

    # Create a mask initialized to one and apply zero at the indices
    mask = torch.ones_like(images)
    mask[batch_indices, :, y_indices, x_indices] = 0

    # Apply mask to images
    return images * mask


def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
    # borrowed from https://github.com/microsoft/moge
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5

    u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
    v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
    u, v = torch.meshgrid(u, v, indexing='xy')
    uv = torch.stack([u, v], dim=-1)
    return uv
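A short usage sketch for the helpers above, assuming the repository root is on the Python path so vggt.heads.utils is importable: a harmonic embedding of a 9-dimensional pose encoding, and a normalized UV grid turned into a sinusoidal positional embedding.

# Exercising HarmonicEmbedding, normalized_view_plane_uv and position_grid_to_embed.
import torch
from vggt.heads.utils import HarmonicEmbedding, normalized_view_plane_uv, position_grid_to_embed

# 9-dim pose encoding, 10 frequencies, no appended input: 9 * 2 * 10 = 180 output channels.
emb = HarmonicEmbedding(n_harmonic_functions=10, append_input=False)
x = torch.randn(2, 5, 9)
print(emb(x).shape, emb.get_output_dim(9))       # torch.Size([2, 5, 180]) 180

# UV grid over a 32x24 view plane turned into a 64-channel sinusoidal embedding.
uv = normalized_view_plane_uv(width=32, height=24)   # (24, 32, 2)
pe = position_grid_to_embed(uv, embed_dim=64)        # (24, 32, 64)
print(uv.shape, pe.shape)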
vggt/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention
vggt/layers/attention.py
ADDED
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn
import torch.nn.functional as F


logger = logging.getLogger("dinov2")


# XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
#     if XFORMERS_ENABLED:
#         from xformers.ops import memory_efficient_attention, unbind
#
#         XFORMERS_AVAILABLE = True
#         warnings.warn("xFormers is available (Attention)")
#     else:
#         warnings.warn("xFormers is disabled (Attention)")
#         raise ImportError
# except ImportError:
#     XFORMERS_AVAILABLE = False
#     warnings.warn("xFormers is not available (Attention)")

XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        proj_bias: bool = True,
        norm_layer: nn.Module = nn.LayerNorm,
        fused_attn: bool = True,
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = fused_attn

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None) -> Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)

        if self.fused_attn:
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class MemEffAttention(Attention):
    def forward(self, x: Tensor, attn_bias=None, pos=None) -> Tensor:
        assert pos is None
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)

        q, k, v = unbind(qkv, 2)

        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        x = x.reshape([B, N, C])

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
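A small consistency check for the Attention module above: in eval mode (dropout disabled) the fused scaled_dot_product_attention path and the explicit softmax path should agree up to numerical precision. The import path assumes the repository root is on the Python path.

# Fused vs. explicit attention paths should produce the same output in eval mode.
import torch
from vggt.layers.attention import Attention

attn = Attention(dim=64, num_heads=8, fused_attn=True).eval()
x = torch.randn(2, 10, 64)
with torch.no_grad():
    y_fused = attn(x)          # F.scaled_dot_product_attention path
    attn.fused_attn = False
    y_manual = attn(x)         # explicit softmax(QK^T / sqrt(d)) V path
print(y_fused.shape, torch.allclose(y_fused, y_manual, atol=1e-5))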
vggt/layers/block.py
ADDED
@@ -0,0 +1,275 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


# XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
# if XFORMERS_ENABLED:
# from xformers.ops import fmha, scaled_index_add, index_select_cat

# XFORMERS_AVAILABLE = True
# warnings.warn("xFormers is available (Block)")
# else:
# warnings.warn("xFormers is disabled (Block)")
# raise ImportError
# except ImportError:
# XFORMERS_AVAILABLE = False

# warnings.warn("xFormers is not available (Block)")

XFORMERS_AVAILABLE = False

class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        fused_attn: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        rope_freq: int = -1,
        rope = None,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)

        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            fused_attn=fused_attn,
            rope=rope,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                pos=pos,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos))
            x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x, pos=pos)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
    pos = None,
) -> Tensor:
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    if pos is not None:
        pos = pos[brange]
        residual = residual_func(x_subset, pos=pos)
    else:
        residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
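Illustrative usage (not part of the commit): a minimal sketch of running one Block in eval mode, where the pre-norm attention and MLP residual branches are applied without stochastic depth; dimensions and hyper-parameters are arbitrary.

# Hypothetical example, assuming the package layout of this commit.
import torch
from vggt.layers.block import Block

blk = Block(dim=64, num_heads=8, init_values=1e-5, drop_path=0.2).eval()
tokens = torch.randn(2, 16, 64)          # (batch, tokens, dim)
with torch.no_grad():
    out = blk(tokens)                    # eval mode -> plain residual path, no sample dropping
assert out.shape == tokens.shape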
vggt/layers/dino_head.py
ADDED
@@ -0,0 +1,58 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm


class DINOHead(nn.Module):
    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x


def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
    if nlayers == 1:
        return nn.Linear(in_dim, bottleneck_dim, bias=bias)
    else:
        layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
        if use_bn:
            layers.append(nn.BatchNorm1d(hidden_dim))
        layers.append(nn.GELU())
        for _ in range(nlayers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
            if use_bn:
                layers.append(nn.BatchNorm1d(hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
        return nn.Sequential(*layers)
vggt/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
vggt/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
vggt/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py


from typing import Callable, Optional

from torch import Tensor, nn


class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
vggt/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

from typing import Callable, Optional, Tuple, Union

from torch import Tensor
import torch.nn as nn


def make_2tuple(x):
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x) # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2) # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim) # B H W C
        return x

    def flops(self) -> float:
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
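Illustrative usage (not part of the commit): PatchEmbed maps an image batch to a token sequence; with the defaults above, a 224x224 input and 16x16 patches give 196 tokens of width 768.

# Hypothetical example, assuming the package layout of this commit.
import torch
from vggt.layers.patch_embed import PatchEmbed

embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
img = torch.randn(1, 3, 224, 224)
tokens = embed(img)                                   # (B, N, D)
assert tokens.shape == (1, (224 // 16) ** 2, 768)     # (1, 196, 768)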
vggt/layers/rope.py
ADDED
@@ -0,0 +1,160 @@
import numpy as np
import torch


class PositionGetter(object):
    """ return positions of patches """

    # NOTE this can take a lot of memory when the patch size is variable

    def __init__(self):
        self.cache_positions = {}

    def __call__(self, b, h, w, device):
        if not (h,w) in self.cache_positions:
            x = torch.arange(w, device=device)
            y = torch.arange(h, device=device)
            self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
        pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
        return pos


# --------------------------------------------------------
# 2D sine-cosine position embedding
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
# MoCo v3: https://github.com/facebookresearch/moco-v3
# --------------------------------------------------------
def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
    """
    grid_size: tuple (height, width) of the grid
    return:
    pos_embed: [grid_size[0]*grid_size[1], embed_dim] or [n_cls_token+grid_size[0]*grid_size[1], embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size[0], dtype=np.float32)
    grid_w = np.arange(grid_size[1], dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h) # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if n_cls_token>0:
        pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.
    omega = 1. / 10000**omega # (D/2,)

    pos = pos.reshape(-1) # (M,)
    out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product

    emb_sin = np.sin(out) # (M, D/2)
    emb_cos = np.cos(out) # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
    return emb


# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------
def interpolate_pos_embed(model, checkpoint_model):
    keys = ['enc_pos_embed']+(['dec_pos_embed'] if hasattr(model,'dec_blocks') else [])
    img_size = model.patch_embed.img_size
    if isinstance(img_size,int): img_size = (img_size,img_size)
    for k in keys:
        if not k in checkpoint_model: continue
        pos_embed_checkpoint = checkpoint_model[k]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_extra_tokens = 0 # no cls token
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        new_size = (img_size[0]//model.patch_embed.patch_size[0],img_size[1]//model.patch_embed.patch_size[1])
        if orig_size != new_size[0] or orig_size != new_size[1]:
            print("Position interpolate %s from %dx%d to %dx%d" % (k, orig_size, orig_size, new_size[0], new_size[1]))
            extra_tokens = pos_embed_checkpoint[:num_extra_tokens,:]
            pos_tokens = pos_embed_checkpoint[num_extra_tokens:,:]
            pos_tokens = pos_tokens.reshape(1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(pos_tokens, size=(new_size[0], new_size[1]), mode='bicubic', align_corners=False)
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2).squeeze(0)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=0)
            checkpoint_model[k] = new_pos_embed.squeeze(0)

#----------------------------------------------------------
# RoPE2D: RoPE implementation in 2D
#----------------------------------------------------------

# borrowed from https://github.com/naver/dust3r
# todo: replace with our official implementation

class RoPE2D(torch.nn.Module):
    def __init__(self, freq=100.0, F0=1.0):
        super().__init__()
        self.base = freq
        self.F0 = F0
        self.cache = {}

    def get_cos_sin(self, D, seq_len, device, dtype):
        if (D,seq_len,device,dtype) not in self.cache:
            inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
            t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
            freqs = torch.cat((freqs, freqs), dim=-1)
            cos = freqs.cos() # (Seq, Dim)
            sin = freqs.sin()
            self.cache[D,seq_len,device,dtype] = (cos,sin)
        return self.cache[D,seq_len,device,dtype]

    @staticmethod
    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def apply_rope1d(self, tokens, pos1d, cos, sin):
        assert pos1d.ndim==2
        cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
        sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
        return (tokens * cos) + (self.rotate_half(tokens) * sin)

    def forward(self, tokens, positions):
        """
        input:
            * tokens: batch_size x nheads x ntokens x dim
            * positions: batch_size x ntokens x 2 (y and x position of each token)
        output:
            * tokens after appplying RoPE2D (batch_size x nheads x ntokens x dim)
        """
        assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
        D = tokens.size(3) // 2
        assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
        cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
        # split features into two along the feature dimension, and apply rope1d on each half
        y, x = tokens.chunk(2, dim=-1)
        y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
        x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
        tokens = torch.cat((y, x), dim=-1)
        return tokens
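Illustrative usage (not part of the commit): rotating a per-head query tensor with the 2-D RoPE above, using PositionGetter to produce the (y, x) grid indices it expects; shapes are arbitrary.

# Hypothetical example, assuming the package layout of this commit.
import torch
from vggt.layers.rope import RoPE2D, PositionGetter

B, heads, H, W, head_dim = 1, 2, 3, 3, 8
q = torch.randn(B, heads, H * W, head_dim)            # (batch, heads, tokens, dim per head)
pos = PositionGetter()(B, H, W, device=q.device)      # (B, H*W, 2) integer (y, x) positions
q_rot = RoPE2D(freq=100.0)(q, pos)                    # half the channels encode y, half encode x
assert q_rot.shape == q.shape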
vggt/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import os
from typing import Callable, Optional
import warnings

from torch import Tensor, nn
import torch.nn.functional as F


class SwiGLUFFN(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
# try:
# if XFORMERS_ENABLED:
# from xformers.ops import SwiGLU

# XFORMERS_AVAILABLE = True
# warnings.warn("xFormers is available (SwiGLU)")
# else:
# warnings.warn("xFormers is disabled (SwiGLU)")
# raise ImportError
# except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False

# warnings.warn("xFormers is not available (SwiGLU)")


class SwiGLUFFNFused(SwiGLU):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
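Illustrative usage (not part of the commit): SwiGLUFFNFused shrinks the requested hidden width by 2/3 and rounds it up to a multiple of 8 before building the two fused projections, as the constructor above shows; the sizes below are arbitrary.

# Hypothetical example, assuming the package layout of this commit.
import torch
from vggt.layers.swiglu_ffn import SwiGLUFFNFused

ffn = SwiGLUFFNFused(in_features=384, hidden_features=4 * 384)   # requested hidden width 1536
assert ffn.w12.out_features == 2 * 1024                          # int(1536 * 2/3) = 1024, already a multiple of 8
out = ffn(torch.randn(2, 10, 384))
assert out.shape == (2, 10, 384)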
vggt/layers/vision_transformer.py
ADDED
@@ -0,0 +1,408 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
from torch.utils.checkpoint import checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
20 |
+
|
21 |
+
logger = logging.getLogger("dinov2")
|
22 |
+
|
23 |
+
|
24 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
25 |
+
if not depth_first and include_root:
|
26 |
+
fn(module=module, name=name)
|
27 |
+
for child_name, child_module in module.named_children():
|
28 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
29 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
30 |
+
if depth_first and include_root:
|
31 |
+
fn(module=module, name=name)
|
32 |
+
return module
|
33 |
+
|
34 |
+
|
35 |
+
class BlockChunk(nn.ModuleList):
|
36 |
+
def forward(self, x):
|
37 |
+
for b in self:
|
38 |
+
x = b(x)
|
39 |
+
return x
|
40 |
+
|
41 |
+
|
42 |
+
class DinoVisionTransformer(nn.Module):
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
img_size=224,
|
46 |
+
patch_size=16,
|
47 |
+
in_chans=3,
|
48 |
+
embed_dim=768,
|
49 |
+
depth=12,
|
50 |
+
num_heads=12,
|
51 |
+
mlp_ratio=4.0,
|
52 |
+
qkv_bias=True,
|
53 |
+
ffn_bias=True,
|
54 |
+
proj_bias=True,
|
55 |
+
drop_path_rate=0.0,
|
56 |
+
drop_path_uniform=False,
|
57 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
58 |
+
embed_layer=PatchEmbed,
|
59 |
+
act_layer=nn.GELU,
|
60 |
+
block_fn=Block,
|
61 |
+
ffn_layer="mlp",
|
62 |
+
block_chunks=1,
|
63 |
+
num_register_tokens=0,
|
64 |
+
interpolate_antialias=False,
|
65 |
+
interpolate_offset=0.1,
|
66 |
+
qk_norm=False,
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Args:
|
70 |
+
img_size (int, tuple): input image size
|
71 |
+
patch_size (int, tuple): patch size
|
72 |
+
in_chans (int): number of input channels
|
73 |
+
embed_dim (int): embedding dimension
|
74 |
+
depth (int): depth of transformer
|
75 |
+
num_heads (int): number of attention heads
|
76 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
77 |
+
qkv_bias (bool): enable bias for qkv if True
|
78 |
+
proj_bias (bool): enable bias for proj in attn if True
|
79 |
+
ffn_bias (bool): enable bias for ffn if True
|
80 |
+
drop_path_rate (float): stochastic depth rate
|
81 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
82 |
+
weight_init (str): weight init scheme
|
83 |
+
init_values (float): layer-scale init values
|
84 |
+
embed_layer (nn.Module): patch embedding layer
|
85 |
+
act_layer (nn.Module): MLP activation layer
|
86 |
+
block_fn (nn.Module): transformer block class
|
87 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
88 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
89 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
90 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
91 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
92 |
+
"""
|
93 |
+
super().__init__()
|
94 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
95 |
+
|
96 |
+
# tricky but makes it work
|
97 |
+
self.use_checkpoint = False
|
98 |
+
#
|
99 |
+
|
100 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
101 |
+
self.num_tokens = 1
|
102 |
+
self.n_blocks = depth
|
103 |
+
self.num_heads = num_heads
|
104 |
+
self.patch_size = patch_size
|
105 |
+
self.num_register_tokens = num_register_tokens
|
106 |
+
self.interpolate_antialias = interpolate_antialias
|
107 |
+
self.interpolate_offset = interpolate_offset
|
108 |
+
|
109 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
110 |
+
num_patches = self.patch_embed.num_patches
|
111 |
+
|
112 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
113 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
114 |
+
assert num_register_tokens >= 0
|
115 |
+
self.register_tokens = (
|
116 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
117 |
+
)
|
118 |
+
|
119 |
+
if drop_path_uniform is True:
|
120 |
+
dpr = [drop_path_rate] * depth
|
121 |
+
else:
|
122 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
123 |
+
|
124 |
+
if ffn_layer == "mlp":
|
125 |
+
logger.info("using MLP layer as FFN")
|
126 |
+
ffn_layer = Mlp
|
127 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
128 |
+
logger.info("using SwiGLU layer as FFN")
|
129 |
+
ffn_layer = SwiGLUFFNFused
|
130 |
+
elif ffn_layer == "identity":
|
131 |
+
logger.info("using Identity layer as FFN")
|
132 |
+
|
133 |
+
def f(*args, **kwargs):
|
134 |
+
return nn.Identity()
|
135 |
+
|
136 |
+
ffn_layer = f
|
137 |
+
else:
|
138 |
+
raise NotImplementedError
|
139 |
+
|
140 |
+
blocks_list = [
|
141 |
+
block_fn(
|
142 |
+
dim=embed_dim,
|
143 |
+
num_heads=num_heads,
|
144 |
+
mlp_ratio=mlp_ratio,
|
145 |
+
qkv_bias=qkv_bias,
|
146 |
+
proj_bias=proj_bias,
|
147 |
+
ffn_bias=ffn_bias,
|
148 |
+
drop_path=dpr[i],
|
149 |
+
norm_layer=norm_layer,
|
150 |
+
act_layer=act_layer,
|
151 |
+
ffn_layer=ffn_layer,
|
152 |
+
init_values=init_values,
|
153 |
+
qk_norm=qk_norm,
|
154 |
+
)
|
155 |
+
for i in range(depth)
|
156 |
+
]
|
157 |
+
if block_chunks > 0:
|
158 |
+
self.chunked_blocks = True
|
159 |
+
chunked_blocks = []
|
160 |
+
chunksize = depth // block_chunks
|
161 |
+
for i in range(0, depth, chunksize):
|
162 |
+
# this is to keep the block index consistent if we chunk the block list
|
163 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
164 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
165 |
+
else:
|
166 |
+
self.chunked_blocks = False
|
167 |
+
self.blocks = nn.ModuleList(blocks_list)
|
168 |
+
|
169 |
+
self.norm = norm_layer(embed_dim)
|
170 |
+
self.head = nn.Identity()
|
171 |
+
|
172 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
173 |
+
|
174 |
+
self.init_weights()
|
175 |
+
|
176 |
+
def init_weights(self):
|
177 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
178 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
179 |
+
if self.register_tokens is not None:
|
180 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
181 |
+
named_apply(init_weights_vit_timm, self)
|
182 |
+
|
183 |
+
def interpolate_pos_encoding(self, x, w, h):
|
184 |
+
previous_dtype = x.dtype
|
185 |
+
npatch = x.shape[1] - 1
|
186 |
+
N = self.pos_embed.shape[1] - 1
|
187 |
+
if npatch == N and w == h:
|
188 |
+
return self.pos_embed
|
189 |
+
pos_embed = self.pos_embed.float()
|
190 |
+
class_pos_embed = pos_embed[:, 0]
|
191 |
+
patch_pos_embed = pos_embed[:, 1:]
|
192 |
+
dim = x.shape[-1]
|
193 |
+
w0 = w // self.patch_size
|
194 |
+
h0 = h // self.patch_size
|
195 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
196 |
+
assert N == M * M
|
197 |
+
kwargs = {}
|
198 |
+
if self.interpolate_offset:
|
199 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
200 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
201 |
+
sx = float(w0 + self.interpolate_offset) / M
|
202 |
+
sy = float(h0 + self.interpolate_offset) / M
|
203 |
+
kwargs["scale_factor"] = (sx, sy)
|
204 |
+
else:
|
205 |
+
# Simply specify an output size instead of a scale factor
|
206 |
+
kwargs["size"] = (w0, h0)
|
207 |
+
patch_pos_embed = nn.functional.interpolate(
|
208 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
209 |
+
mode="bicubic",
|
210 |
+
antialias=self.interpolate_antialias,
|
211 |
+
**kwargs,
|
212 |
+
)
|
213 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
214 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
215 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
216 |
+
|
217 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
218 |
+
B, nc, w, h = x.shape
|
219 |
+
x = self.patch_embed(x)
|
220 |
+
if masks is not None:
|
221 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
222 |
+
|
223 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
224 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
225 |
+
|
226 |
+
if self.register_tokens is not None:
|
227 |
+
x = torch.cat(
|
228 |
+
(
|
229 |
+
x[:, :1],
|
230 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
231 |
+
x[:, 1:],
|
232 |
+
),
|
233 |
+
dim=1,
|
234 |
+
)
|
235 |
+
|
236 |
+
return x
|
237 |
+
|
238 |
+
def forward_features_list(self, x_list, masks_list):
|
239 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
240 |
+
|
241 |
+
|
242 |
+
for blk in self.blocks:
|
243 |
+
if self.use_checkpoint:
|
244 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
245 |
+
else:
|
246 |
+
x = blk(x)
|
247 |
+
|
248 |
+
all_x = x
|
249 |
+
output = []
|
250 |
+
for x, masks in zip(all_x, masks_list):
|
251 |
+
x_norm = self.norm(x)
|
252 |
+
output.append(
|
253 |
+
{
|
254 |
+
"x_norm_clstoken": x_norm[:, 0],
|
255 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
256 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
257 |
+
"x_prenorm": x,
|
258 |
+
"masks": masks,
|
259 |
+
}
|
260 |
+
)
|
261 |
+
return output
|
262 |
+
|
263 |
+
def forward_features(self, x, masks=None):
|
264 |
+
if isinstance(x, list):
|
265 |
+
return self.forward_features_list(x, masks)
|
266 |
+
|
267 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
268 |
+
|
269 |
+
for blk in self.blocks:
|
270 |
+
if self.use_checkpoint:
|
271 |
+
x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
|
272 |
+
else:
|
273 |
+
x = blk(x)
|
274 |
+
|
275 |
+
x_norm = self.norm(x)
|
276 |
+
return {
|
277 |
+
"x_norm_clstoken": x_norm[:, 0],
|
278 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
279 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
280 |
+
"x_prenorm": x,
|
281 |
+
"masks": masks,
|
282 |
+
}
|
283 |
+
|
284 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
285 |
+
x = self.prepare_tokens_with_masks(x)
|
286 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
287 |
+
output, total_block_len = [], len(self.blocks)
|
288 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
289 |
+
for i, blk in enumerate(self.blocks):
|
290 |
+
x = blk(x)
|
291 |
+
if i in blocks_to_take:
|
292 |
+
output.append(x)
|
293 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
294 |
+
return output
|
295 |
+
|
296 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
297 |
+
x = self.prepare_tokens_with_masks(x)
|
298 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
299 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
300 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
301 |
+
for block_chunk in self.blocks:
|
302 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
303 |
+
x = blk(x)
|
304 |
+
if i in blocks_to_take:
|
305 |
+
output.append(x)
|
306 |
+
i += 1
|
307 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
308 |
+
return output
|
309 |
+
|
310 |
+
def get_intermediate_layers(
|
311 |
+
self,
|
312 |
+
x: torch.Tensor,
|
313 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
314 |
+
reshape: bool = False,
|
315 |
+
return_class_token: bool = False,
|
316 |
+
norm=True,
|
317 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
318 |
+
if self.chunked_blocks:
|
319 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
320 |
+
else:
|
321 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
322 |
+
if norm:
|
323 |
+
outputs = [self.norm(out) for out in outputs]
|
324 |
+
class_tokens = [out[:, 0] for out in outputs]
|
325 |
+
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
|
326 |
+
if reshape:
|
327 |
+
B, _, w, h = x.shape
|
328 |
+
outputs = [
|
329 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
330 |
+
for out in outputs
|
331 |
+
]
|
332 |
+
if return_class_token:
|
333 |
+
return tuple(zip(outputs, class_tokens))
|
334 |
+
return tuple(outputs)
|
335 |
+
|
336 |
+
def forward(self, *args, is_training=True, **kwargs):
|
337 |
+
ret = self.forward_features(*args, **kwargs)
|
338 |
+
if is_training:
|
339 |
+
return ret
|
340 |
+
else:
|
341 |
+
return self.head(ret["x_norm_clstoken"])
|
342 |
+
|
343 |
+
|
344 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
345 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
346 |
+
if isinstance(module, nn.Linear):
|
347 |
+
trunc_normal_(module.weight, std=0.02)
|
348 |
+
if module.bias is not None:
|
349 |
+
nn.init.zeros_(module.bias)
|
350 |
+
|
351 |
+
|
352 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
353 |
+
model = DinoVisionTransformer(
|
354 |
+
patch_size=patch_size,
|
355 |
+
embed_dim=384,
|
356 |
+
depth=12,
|
357 |
+
num_heads=6,
|
358 |
+
mlp_ratio=4,
|
359 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
360 |
+
num_register_tokens=num_register_tokens,
|
361 |
+
**kwargs,
|
362 |
+
)
|
363 |
+
return model
|
364 |
+
|
365 |
+
|
366 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
367 |
+
model = DinoVisionTransformer(
|
368 |
+
patch_size=patch_size,
|
369 |
+
embed_dim=768,
|
370 |
+
depth=12,
|
371 |
+
num_heads=12,
|
372 |
+
mlp_ratio=4,
|
373 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
374 |
+
num_register_tokens=num_register_tokens,
|
375 |
+
**kwargs,
|
376 |
+
)
|
377 |
+
return model
|
378 |
+
|
379 |
+
|
380 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
381 |
+
model = DinoVisionTransformer(
|
382 |
+
patch_size=patch_size,
|
383 |
+
embed_dim=1024,
|
384 |
+
depth=24,
|
385 |
+
num_heads=16,
|
386 |
+
mlp_ratio=4,
|
387 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
388 |
+
num_register_tokens=num_register_tokens,
|
389 |
+
**kwargs,
|
390 |
+
)
|
391 |
+
return model
|
392 |
+
|
393 |
+
|
394 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
395 |
+
"""
|
396 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
397 |
+
"""
|
398 |
+
model = DinoVisionTransformer(
|
399 |
+
patch_size=patch_size,
|
400 |
+
embed_dim=1536,
|
401 |
+
depth=40,
|
402 |
+
num_heads=24,
|
403 |
+
mlp_ratio=4,
|
404 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
405 |
+
num_register_tokens=num_register_tokens,
|
406 |
+
**kwargs,
|
407 |
+
)
|
408 |
+
return model
|
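Illustrative usage (not part of the commit): building one of the DINOv2 backbones exposed above (vit_small here to keep the example light) and reading the feature dictionary its forward pass returns; the image size, register-token count, and other values are arbitrary, not the project's training configuration.

# Hypothetical example, assuming the package layout of this commit.
import torch
from vggt.layers.vision_transformer import vit_small

backbone = vit_small(img_size=224, patch_size=16, num_register_tokens=4, block_chunks=0, init_values=1.0)
out = backbone(torch.randn(1, 3, 224, 224))   # is_training=True by default -> dict of features
print(out["x_norm_patchtokens"].shape)        # torch.Size([1, 196, 384])
print(out["x_norm_regtokens"].shape)          # torch.Size([1, 4, 384])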
vggt/models/aggregator.py
ADDED
@@ -0,0 +1,473 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import logging
|
9 |
+
|
10 |
+
import pdb
|
11 |
+
import math
|
12 |
+
import numpy as np
|
13 |
+
import os
|
14 |
+
import torch
|
15 |
+
import torch.nn as nn
|
16 |
+
import torch.nn.functional as F
|
17 |
+
from hydra.utils import instantiate
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from torch.utils.checkpoint import checkpoint
|
21 |
+
from omegaconf import OmegaConf
|
22 |
+
from contextlib import nullcontext
|
23 |
+
|
24 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
25 |
+
|
26 |
+
# from off3d.utils.train_utils import remove_if_not_match

# from off3d.models.modules import AttnBlock, CrossAttnBlock, Mlp, ResidualBlock, RoPEAttnBlock
# from vggsfm.models.utils import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed_from_grid
# from off3d.models.dino_layers import SwiGLUFFNFused, PatchEmbed

from vggt.layers import SwiGLUFFNFused, PatchEmbed
from vggt.layers.block import Block

# from off3d.models.dino_layers.block import Block
# from vggt.layers.rope import RoPE2D, PositionGetter
from vggt.layers.rope import RoPE2D, PositionGetter

# from off3d.models.multihead_with_qk_norm import MultiheadAttention_with_qk_norm
# from off3d.models.rope import RoPEMulitheadAttention

from vggt.layers.vision_transformer import vit_small, vit_base, vit_large, vit_giant2


logger = logging.getLogger(__name__)


_RESNET_MEAN = [0.485, 0.456, 0.406]
_RESNET_STD = [0.229, 0.224, 0.225]


class Aggregator(nn.Module):
    def __init__(
        self,
        image_size=512,
        patch_size=16,
        num_register_tokens=4,
        image_backbone="dinov2_vitl14_reg",
        aa_block_size=1,
        aa_layer_size=24,
        aa_block_kwargs: Dict = None,  # must be provided; holds the per-block settings (dim, rope_freq, ...)
        attn_block=Block,
        aa_order=["frame", "global"],
        use_checkpoint=False,
        use_reentrant=False,
        use_dino_tokens=False,
        use_patch_tokens_only=False,
        freeze_dino=False,
        freeze_dino_inter=False,
        # pose_embed=False,
        embed_type="no",
        patch_embed_by_conv=False,
        decoder_load_dino=False,
        backbone_qk_norm=False,
        **kwargs,
    ):
        super().__init__()

        if image_backbone is None:
            self.image_backbone = None
        else:
            self.__build_image_backbone__(image_backbone, image_size,
                                          patch_size, num_register_tokens, freeze_dino=freeze_dino,
                                          freeze_dino_inter=freeze_dino_inter, backbone_qk_norm=backbone_qk_norm)

        self.freeze_dino = freeze_dino

        if self.image_backbone is not None:
            if use_checkpoint and not freeze_dino:
                self.image_backbone.use_checkpoint = True
            else:
                self.image_backbone.use_checkpoint = False

            self.image_backbone.use_reentrant = use_reentrant

        if aa_block_kwargs['rope_freq'] > 0:
            self.rope = RoPE2D(freq=aa_block_kwargs['rope_freq'])
            self.position_getter = PositionGetter()
        else:
            self.rope = None

        frame_blocks_list = []
        global_blocks_list = []
        for _ in range(aa_layer_size):
            frame_blocks_list.append(attn_block(**aa_block_kwargs, rope=self.rope))
            global_blocks_list.append(attn_block(**aa_block_kwargs, rope=self.rope))

        self.frame_blocks = nn.ModuleList(frame_blocks_list)
        self.global_blocks = nn.ModuleList(global_blocks_list)

        if "mlp" in embed_type:
            self.register_mlp = nn.ModuleList([nn.Linear(aa_block_kwargs['dim'], aa_block_kwargs['dim']) for _ in range(aa_layer_size)])

        self.aa_order = aa_order
        self.aa_block_size = aa_block_size
        self.aa_layer_size = aa_layer_size

        assert self.aa_layer_size % self.aa_block_size == 0, "aa_layer_size must be divisible by aa_block_size"
        self.aa_block_num = self.aa_layer_size // self.aa_block_size

        self.patch_size = patch_size
        self.use_checkpoint = use_checkpoint
        self.use_reentrant = use_reentrant
        self.use_dino_tokens = use_dino_tokens
        self.use_patch_tokens_only = use_patch_tokens_only
        # self.pose_embed = pose_embed
        # self.register_embed = register_embed
        self.embed_type = embed_type

        if self.use_patch_tokens_only:
            self.query_ref_token = nn.Parameter(torch.randn(1, 2, 1, aa_block_kwargs['dim']))
            self.patch_start_idx = 0
            nn.init.normal_(self.query_ref_token, std=1e-6)
        elif self.use_dino_tokens:
            # One for the query frame and one for the other frames
            self.query_ref_token = nn.Parameter(torch.randn(1, 2, 1, aa_block_kwargs['dim']))
            self.patch_start_idx = 1 + num_register_tokens + 1
            nn.init.normal_(self.query_ref_token, std=1e-6)
        else:
            self.pose_token = nn.Parameter(torch.randn(1, 2, 1, aa_block_kwargs['dim']))
            self.register_token = nn.Parameter(torch.randn(1, 2, num_register_tokens, aa_block_kwargs['dim']))
            self.patch_start_idx = 1 + num_register_tokens
            nn.init.normal_(self.pose_token, std=1e-6)
            nn.init.normal_(self.register_token, std=1e-6)

        if decoder_load_dino:
            dinov2_weights = self.image_backbone.state_dict()
            decoder_dinov2_weights = dino_to_aggregator(dinov2_weights)
            missing_keys, unexpected_keys = self.load_state_dict(decoder_dinov2_weights, strict=False)
            print(f"missing_keys for decoder_load_dino: {missing_keys}")
            print(f"unexpected_keys for decoder_load_dino: {unexpected_keys}")

        if patch_embed_by_conv:
            self.image_backbone = self.image_backbone.patch_embed

        for name, value in (
            ("_resnet_mean", _RESNET_MEAN),
            ("_resnet_std", _RESNET_STD),
        ):
            self.register_buffer(
                name,
                torch.FloatTensor(value).view(1, 1, 3, 1, 1),
                persistent=False,
            )

    def __build_image_backbone__(self, image_backbone, image_size, patch_size, num_register_tokens,
                                 interpolate_antialias=True,
                                 interpolate_offset=0.0,
                                 block_chunks=0,
                                 init_values=1.0,
                                 freeze_dino=False,
                                 freeze_dino_inter=False,
                                 backbone_qk_norm=False,
                                 ):

        vit_models = {"dinov2_vitl14_reg": vit_large,
                      "dinov2_vitb14_reg": vit_base,
                      "dinov2_vits14_reg": vit_small,
                      "dinov2_vitg2_reg": vit_giant2,
                      }

        if image_backbone not in vit_models:
            raise NotImplementedError

        self.image_backbone = vit_models[image_backbone](img_size=image_size,
                                                         patch_size=patch_size, num_register_tokens=num_register_tokens,
                                                         interpolate_antialias=interpolate_antialias,
                                                         interpolate_offset=interpolate_offset,
                                                         block_chunks=block_chunks, init_values=init_values, qk_norm=backbone_qk_norm)

        # pretrained_model = torch.hub.load("facebookresearch/dinov2", image_backbone)
        # pretrained_model_dict = pretrained_model.state_dict()
        # image_backbone_dict = self.image_backbone.state_dict()

        # all_pretrained_keys = list(pretrained_model_dict.keys())

        # for cur_key in all_pretrained_keys:
        #     pretrained_model_dict = remove_if_not_match(image_backbone_dict, pretrained_model_dict, cur_key)

        # missing_keys, unexpected_keys = self.image_backbone.load_state_dict(pretrained_model_dict, strict=False)

        self.image_backbone.mask_token.requires_grad_(False)
        # self.image_backbone.freeze_dino = freeze_dino

        # if freeze_dino:
        #     print("Freezing DINO layers")
        #     for name, param in self.image_backbone.named_parameters():
        #         param.requires_grad_(False)

        # if freeze_dino_inter:
        #     print("Freezing DINO intermediate layers")
        #     for name, param in self.image_backbone.named_parameters():
        #         if name not in ['pos_embed', 'patch_embed.proj.weight']:
        #             param.requires_grad_(False)

        # print("Loading pretrained DINO v2 model: ")
        # print(f"missing_keys: {missing_keys}")
        # print("Loading pretrained DINO v2 model: ")
        # print(f"unexpected_keys: {unexpected_keys}")

    def forward(
        self, images,
        masks=None,
        batch=None,
    ):
        """
        Args:
            images: (B, S, 3, H, W) images in the range [0, 1].
            masks, batch: currently unused, kept for interface compatibility.

        Returns:
            output_list: list of (B, S, P, 2C) tensors, the frame-wise and global
                intermediate tokens of each attention block concatenated along the channel dim.
            None: placeholder kept for interface compatibility.
            patch_start_idx: index of the first patch token inside the token sequence.
        """
        # The input images are in the range of [0, 1]
        B, S, C_in, H, W = images.shape
        device = images.device

        images = (images - self._resnet_mean) / self._resnet_std

        if self.image_backbone is not None:
            images = images.view(B * S, C_in, H, W)

            with torch.no_grad() if self.freeze_dino else nullcontext():
                backbone_output = self.image_backbone(images)

            if isinstance(backbone_output, dict):
                patch_tokens = backbone_output["x_norm_patchtokens"]
            else:
                patch_tokens = backbone_output

            BS, P, C = patch_tokens.shape

            if self.use_patch_tokens_only:
                indicator_tokens = slice_expand_and_flatten(self.query_ref_token, B, S)
                tokens = patch_tokens + indicator_tokens
            elif self.use_dino_tokens:
                dino_cls_token = backbone_output["x_norm_clstoken"][:, None]  # BS, 1, C
                dino_register_tokens = backbone_output["x_norm_regtokens"]    # BS, num_register_tokens, C

                indicator_tokens = slice_expand_and_flatten(self.query_ref_token, B, S)
                tokens = torch.cat([dino_cls_token, dino_register_tokens, indicator_tokens, patch_tokens], dim=1)
            else:
                # B, S, P, C
                pose_token = slice_expand_and_flatten(self.pose_token, B, S)
                register_token = slice_expand_and_flatten(self.register_token, B, S)

                tokens = torch.cat([pose_token, register_token, patch_tokens], dim=1)
        else:
            # Running without an image backbone is not supported yet.
            raise NotImplementedError

        if self.rope is not None:
            pos = self.position_getter(B * S, H // self.patch_size, W // self.patch_size, device=device)
        else:
            pos = None

        if pos is not None and self.patch_start_idx > 0:
            # shift the position by 1 so that the special tokens are at 0
            pos = pos + 1
            pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(device).to(pos.dtype)
            pos = torch.cat([pos_special, pos], dim=1)

        _, P, C = tokens.shape

        frame_idx = 0
        global_idx = 0
        output_list = []

        for aa_block_idx in range(self.aa_block_num):
            for attn_type in self.aa_order:
                if attn_type == "frame":
                    tokens, frame_idx, frame_intermediates = self._process_frame_attention(
                        tokens, B, S, P, C, frame_idx, self.aa_block_size, pos=pos
                    )
                elif attn_type == "global":
                    tokens, global_idx, global_intermediates = self._process_global_attention(
                        tokens, B, S, P, C, global_idx, self.aa_block_size, pos=pos
                    )
                else:
                    raise ValueError(f"Unknown attention type: {attn_type}")

            # for frame_inter, global_inter in zip(frame_intermediates, global_intermediates):
            #     concat_inter = torch.cat([frame_inter, global_inter], dim=-1)  # [B x S x P x 2C]
            #     output_list.append(concat_inter)

            for i in range(len(frame_intermediates)):
                # [B x S x P x 2C]
                concat_inter = torch.cat([frame_intermediates[i], global_intermediates[i]], dim=-1)
                output_list.append(concat_inter)

        del concat_inter
        del frame_intermediates
        del global_intermediates
        return output_list, None, self.patch_start_idx

    def _process_frame_attention(self, tokens, B, S, P, C, frame_idx, num_blocks, pos=None):
        """
        Process frame attention blocks.
        """
        if tokens.shape != (B * S, P, C):
            tokens = tokens.view(B, S, P, C)
            tokens = tokens.view(B * S, P, C)

        if pos is not None and pos.shape != (B * S, P, 2):
            pos = pos.view(B, S, P, 2)
            pos = pos.view(B * S, P, 2)

        intermediates = []

        for _ in range(num_blocks):
            if self.use_checkpoint:
                tokens = checkpoint(self.frame_blocks[frame_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.frame_blocks[frame_idx](tokens, pos=pos)
            frame_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, frame_idx, intermediates

    def _process_global_attention(self, tokens, B, S, P, C, global_idx, num_blocks, pos=None):
        """
        Process global attention blocks.
        """
        # pose_embed

        if tokens.shape != (B, S * P, C):
            tokens = tokens.view(B, S, P, C)

        ############################################################
        # Frame embedding
        if "register" in self.embed_type:
            embed_tokens = tokens[:, :, 1:2, ...].clone()
        if "gauss" in self.embed_type:
            embed_tokens = torch.randn((B, S, 1, C), device=tokens.device, dtype=tokens.dtype)

        if self.embed_type != "no":
            embed_tokens = F.normalize(embed_tokens, dim=-1)

            if "mlp" in self.embed_type:
                embed_tokens = self.register_mlp[global_idx](embed_tokens)

                if "mlpnorm" in self.embed_type:
                    embed_tokens = F.normalize(embed_tokens, dim=-1)

        if "all" in self.embed_type:
            tokens = tokens + embed_tokens
        elif "part" in self.embed_type:
            tokens[:, :, self.patch_start_idx:] = tokens[:, :, self.patch_start_idx:] + embed_tokens
        else:
            assert self.embed_type == "no"

        if "postnorm" in self.embed_type:
            tokens = F.normalize(tokens, dim=-1)
            # tokens = self.embed_norm(tokens)
        ############################################################

        tokens = tokens.view(B, S * P, C)

        if pos is not None and pos.shape != (B, S * P, 2):
            pos = pos.view(B, S, P, 2)
            pos = pos.view(B, S * P, 2)

        intermediates = []
        for _ in range(num_blocks):
            if self.use_checkpoint:
                tokens = checkpoint(self.global_blocks[global_idx], tokens, pos, use_reentrant=self.use_reentrant)
            else:
                tokens = self.global_blocks[global_idx](tokens, pos=pos)
            global_idx += 1
            intermediates.append(tokens.view(B, S, P, C))

        return tokens, global_idx, intermediates


def slice_expand_and_flatten(token_tensor, B, S):
    """
    1) Takes the first token (index=0) and the remaining tokens (index=1..S-1).
    2) Expands them along the batch dimension B.
    3) Concatenates along the time/sequence dimension => (B, S, ...).
    4) Flattens the first two dims to produce => (B*S, ...).

    Args:
        token_tensor: a tensor expected to have shape (1, S, ...) or (some_batch, S, ...).
            We'll slice along dim=1.
        B: batch size.
        S: number of frames/time-steps.

    Returns:
        Flattened token tensor of shape (B*S, ...).
    """
    # Slice out the "query" tokens => shape (B, 1, ...)
    query = token_tensor[:, 0:1, ...].expand(B, 1, *token_tensor.shape[2:])
    # Slice out the "other" tokens => shape (B, S-1, ...)
    others = token_tensor[:, 1:, ...].expand(B, S - 1, *token_tensor.shape[2:])
    # Concatenate => shape (B, S, ...)
    combined = torch.cat([query, others], dim=1)

    # Finally flatten => shape (B*S, ...)
    combined = combined.view(B * S, *combined.shape[2:])
    return combined


def dino_to_aggregator(dinov2_weights):
    new_dinov2_weights = {}
    for key, value in dinov2_weights.items():
        if "blocks" in key:
            for new_attn_key in ["frame_blocks", "global_blocks"]:
                new_key = key.replace("blocks", new_attn_key)
                # if 'attn' in key:
                #     if "qkv.weight" in key:
                #         new_key = new_key.replace('qkv.weight', 'in_proj_weight')
                #     elif "qkv.bias" in key:
                #         new_key = new_key.replace('qkv.bias', 'in_proj_bias')
                #     elif 'proj.weight' in key:
                #         new_key = new_key.replace('proj.weight', 'out_proj.weight')
                #     elif 'proj.bias' in key:
                #         new_key = new_key.replace('proj.bias', 'out_proj.bias')
                new_dinov2_weights[new_key] = value.clone()
    return new_dinov2_weights


def remove_if_not_match(model_state_dict, state_dict, key):
    if key in state_dict.keys() and key in model_state_dict.keys():
        if state_dict[key].shape != model_state_dict[key].shape:
            print(f"Warning: {key} shape mismatch, removing it")
            del state_dict[key]
    return state_dict
vggt/models/vggt.py
ADDED
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
from typing import Any, Dict, List, Optional, Tuple, Union

# from off3d.models.vggt.utils import random_mask_single_patch_vectorized  # Removed unused import
from hydra.utils import instantiate
# from .loss import *


def configure_dict(module, **attributes):
    if module:
        for attr, value in attributes.items():
            setattr(module, attr, value)


class VGGT(nn.Module):
    def __init__(self,
                 AGGREGATOR: Dict,
                 CameraHead: Dict,
                 PointHead: Dict,
                 DepthHead: Dict,
                 MatchHead: Dict,
                 TrackHead: Dict,
                 num_register_tokens,
                 init_values,
                 qk_norm,
                 ffn_layer,
                 patch_size,
                 enable_head_mp=False,
                 **kwargs):
        super().__init__()

        config_attrs = {
            'patch_size': patch_size,
            'init_values': init_values,
            'qk_norm': qk_norm,
            'ffn_layer': ffn_layer,
            'num_register_tokens': num_register_tokens
        }

        if AGGREGATOR:
            configure_dict(AGGREGATOR, **config_attrs)
            self.aggregator = instantiate(AGGREGATOR, _recursive_=False)
        else:
            self.aggregator = None

        if CameraHead:
            configure_dict(CameraHead, **config_attrs)
            CameraHead.loss_kwargs.pose_encoding_type = CameraHead.pose_encoding_type
            self.camera_head_loss_kwargs = CameraHead.loss_kwargs
            self.camera_head = instantiate(CameraHead, _recursive_=False)
        else:
            self.camera_head = None

        if PointHead:
            configure_dict(PointHead, **config_attrs)
            self.point_head_loss_kwargs = PointHead.loss_kwargs
            self.point_head = instantiate(PointHead, _recursive_=False)
        else:
            self.point_head = None

        if DepthHead:
            configure_dict(DepthHead, **config_attrs)
            self.depth_head_loss_kwargs = DepthHead.loss_kwargs
            self.depth_head = instantiate(DepthHead, _recursive_=False)
        else:
            self.depth_head = None

        if MatchHead:
            configure_dict(MatchHead, **config_attrs)
            self.match_head_loss_kwargs = MatchHead.loss_kwargs
            self.match_head = instantiate(MatchHead, _recursive_=False)
        else:
            self.match_head = None

        if TrackHead:
            configure_dict(TrackHead, **config_attrs)
            self.track_head_loss_kwargs = TrackHead.loss_kwargs
            self.track_head = instantiate(TrackHead, _recursive_=False)
        else:
            self.track_head = None

        self.enable_head_mp = enable_head_mp
        # self.mask_patch_ratio = mask_patch_ratio
        # self.mask_patch_size = mask_patch_size

    def forward(self, batch, device=None):
        images = batch["images"]  # .to(device)  # B x S x 3 x H x W
        # intrinsics = batch["intrinsics"]  # .to(device)
        # extrinsics = batch["extrinsics"]  # .to(device)
        B, S, C, H, W = images.shape

        # if self.training and self.mask_patch_ratio > 0:  # Commented out masking
        #     for _ in range(1000):
        #         print("Please do not use mask_patch_ratio for now")

        aggregated_tokens_list, _, patch_start_idx = self.aggregator(images, batch=batch)

        predictions = {}

        # By default, amp is used for the track head
        if self.track_head is not None:
            track_loss_dict = self.track_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
            predictions.update(track_loss_dict)

        with torch.cuda.amp.autocast(enabled=self.enable_head_mp):
            # Pose branch
            if self.camera_head is not None:
                pred_pose_enc_list = self.camera_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
                camera_loss_dict = {}
                camera_loss_dict["pred_extrinsic_list"] = pred_pose_enc_list
                # with torch.cuda.amp.autocast(enabled=False):
                #     if not isinstance(pred_pose_enc_list, dict):
                #         camera_loss_dict, last_pred_extrinsic = camera_loss(pred_pose_enc_list, batch, **self.camera_head_loss_kwargs)
                #         predictions["pred_extrinsic"] = last_pred_extrinsic
                #     else:
                #         camera_loss_dict = pred_pose_enc_list
                predictions.update(camera_loss_dict)

            if self.point_head is not None:
                pts3d, pts3d_conf = self.point_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
                # with torch.cuda.amp.autocast(enabled=False):
                #     pts3d_loss_dict = point_loss(pts3d, pts3d_conf, batch, **self.point_head_loss_kwargs)
                #     predictions.update(pts3d_loss_dict)
                predictions["pred_world_points"] = pts3d
                predictions["pred_world_points_conf"] = pts3d_conf

            if self.depth_head is not None:
                depth, depth_conf = self.depth_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
                # with torch.cuda.amp.autocast(enabled=False):
                #     depth_loss_dict = depth_loss(depth, depth_conf, batch, **self.depth_head_loss_kwargs)
                #     predictions.update(depth_loss_dict)
                predictions["pred_depth"] = depth
                predictions["pred_depth_conf"] = depth_conf

            if self.match_head is not None:
                match_loss_dict = self.match_head(aggregated_tokens_list, batch=batch, patch_start_idx=patch_start_idx)
                predictions.update(match_loss_dict)

        predictions.update(batch)

        return predictions
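
A sketch (not part of the commit) of the input/output contract of VGGT.forward, assuming `model` is a hypothetical, already-instantiated VGGT built from the project's Hydra config; the exact prediction keys depend on which heads are enabled:

import torch

# A batch only needs "images" for inference-style usage; heads may read extra keys from `batch`.
batch = {"images": torch.rand(1, 4, 3, 512, 512)}   # B x S x 3 x H x W, values in [0, 1]

# predictions = model(batch)   # `model` is assumed to exist; not constructed here
# Depending on the enabled heads, `predictions` can contain:
#   "pred_extrinsic_list"                           (camera head)
#   "pred_world_points", "pred_world_points_conf"   (point head)
#   "pred_depth", "pred_depth_conf"                 (depth head)
# plus whatever the track/match heads return and everything already in `batch`.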
vggt/utils/pose_enc.py
ADDED
@@ -0,0 +1,126 @@
import torch
from .rotation import quat_to_mat, mat_to_quat
# from off3d.utils.metric import closed_form_inverse_OpenCV


def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,  # e.g., (256, 512)
    pose_encoding_type="absT_quaR_FoV",
    min_focal_length=0.1,
    max_focal_length=10,
):
    # extrinsics: BxSx3x4
    # intrinsics: BxSx3x3

    if pose_encoding_type == "absT_quaR_FoV":
        R = extrinsics[:, :, :3, :3]  # BxSx3x3
        T = extrinsics[:, :, :3, 3]   # BxSx3

        quat = mat_to_quat(R)
        # R_reverse = quat_to_mat(quat)
        # Note the order of h and w here
        H, W = image_size_hw
        fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
        fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
        pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
    elif pose_encoding_type == "absT_quaR_OneFLM1":
        # raise ValueError("Not checked after migrating to off3d.")
        focal_length = intrinsics[:, :, [0, 1], [0, 1]] / max(image_size_hw)
        focal_length = focal_length.mean(dim=-1)
        focal_length = focal_length.clamp(min_focal_length, max_focal_length)
        focal_length = focal_length - 1
        R = extrinsics[:, :, :3, :3]
        T = extrinsics[:, :, :3, 3]
        quat = mat_to_quat(R)
        pose_encoding = torch.cat([T, quat, focal_length[..., None]], dim=-1).float()
    else:
        raise NotImplementedError

    return pose_encoding


def pose_encoding_to_extri_intri(
    pose_encoding,
    image_size_hw=None,  # e.g., (256, 512)
    min_focal_length=0.1,
    max_focal_length=10,
    pose_encoding_type="absT_quaR_FoV",
    build_intrinsics=True,
):
    intrinsics = None

    if pose_encoding_type == "absT_quaR_FoV":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]
        fov_h = pose_encoding[..., 7]
        fov_w = pose_encoding[..., 8]

        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)

        if build_intrinsics:
            H, W = image_size_hw
            fy = (H / 2.0) / torch.tan(fov_h / 2.0)
            fx = (W / 2.0) / torch.tan(fov_w / 2.0)
            intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
            intrinsics[..., 0, 0] = fx
            intrinsics[..., 1, 1] = fy
            intrinsics[..., 0, 2] = W / 2
            intrinsics[..., 1, 2] = H / 2
            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
    elif pose_encoding_type == "absT_quaR_OneFLM1":
        T = pose_encoding[..., :3]
        quat = pose_encoding[..., 3:7]
        focal_length_encoded = pose_encoding[..., 7]
        focal_length = (focal_length_encoded + 1).clamp(min_focal_length, max_focal_length)
        focal_length = focal_length * max(image_size_hw)
        R = quat_to_mat(quat)
        extrinsics = torch.cat([R, T[..., None]], dim=-1)

        if build_intrinsics:
            intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
            intrinsics[..., 0, 0] = focal_length
            intrinsics[..., 1, 1] = focal_length
            intrinsics[..., 0, 2] = image_size_hw[1] / 2
            intrinsics[..., 1, 2] = image_size_hw[0] / 2

            # NOTE something is wrong here
            intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
            # TODO double check the principal point: is image_size_hw indexed as (h, w) or (w, h)?
    else:
        raise NotImplementedError

    return extrinsics, intrinsics


def test_pose_encoding():
    num_tests = 1000
    batch_size = 4
    num_cameras = 2
    image_size_hw = (256, 512)
    min_focal_length = 0.1
    max_focal_length = 30
    pose_encoding_type = "absT_quaR_OneFLM1"

    for _ in range(num_tests):
        # Generate random extrinsics and intrinsics
        pose_encoding = torch.randn(batch_size, num_cameras, 8)

        # converting forward and backward, and verifying the consistency
        extrinsics, intrinsics = pose_encoding_to_extri_intri(pose_encoding, image_size_hw, min_focal_length, max_focal_length, pose_encoding_type)
        pose_encoding_back = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw, pose_encoding_type, min_focal_length, max_focal_length)
        extrinsics_forward, intrinsics_forward = pose_encoding_to_extri_intri(pose_encoding_back, image_size_hw, min_focal_length, max_focal_length, pose_encoding_type)
        pose_encoding_forward = extri_intri_to_pose_encoding(extrinsics_forward, intrinsics_forward, image_size_hw, pose_encoding_type, min_focal_length, max_focal_length)
        assert torch.allclose(pose_encoding_forward[..., :7], pose_encoding_back[..., :7], atol=1e-5), "Pose encoding does not match!"
    print("All tests passed!")


if __name__ == "__main__":
    test_pose_encoding()
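
A runnable round-trip sketch (not part of the file) for the default "absT_quaR_FoV" encoding, assuming the two functions above are importable from vggt.utils.pose_enc:

import torch
from vggt.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri

B, S, H, W = 2, 4, 256, 512
extrinsics = torch.eye(4)[:3].expand(B, S, 3, 4).contiguous()   # identity world-to-camera, BxSx3x4
intrinsics = torch.tensor([[300.0, 0.0, W / 2],
                           [0.0, 300.0, H / 2],
                           [0.0, 0.0, 1.0]]).expand(B, S, 3, 3)

enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, (H, W))   # BxSx9: T(3) + quat(4) + fov_h + fov_w
extr_back, intr_back = pose_encoding_to_extri_intri(enc, (H, W))

assert torch.allclose(extr_back, extrinsics, atol=1e-5)
assert torch.allclose(intr_back, intrinsics, atol=1e-3)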
vggt/utils/rotation.py
ADDED
@@ -0,0 +1,200 @@
# Modified from PyTorch3D

import torch
import numpy as np
import torch.nn.functional as F
from scipy.spatial.transform import Rotation as R


def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Quaternion Order: XYZW or say ijkr, scalar-last

    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # we produce the desired quaternion multiplied by each of r, i, j, k
    quat_by_rijk = torch.stack(
        [
            # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and `int`.
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # We floor here at 0.1 but the exact level is not important; if q_abs is small,
    # the candidate won't be picked.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
    # forall i; we pick the best-conditioned one (with the largest denominator)
    out = quat_candidates[
        F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
    ].reshape(batch_dim + (4,))

    # Convert from rijk to ijkr
    out = out[..., [1, 2, 3, 0]]

    out = standardize_quaternion(out)

    return out


def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
    """
    Returns torch.sqrt(torch.max(0, x))
    but with a zero subgradient where x is 0.
    """
    ret = torch.zeros_like(x)
    positive_mask = x > 0
    if torch.is_grad_enabled():
        ret[positive_mask] = torch.sqrt(x[positive_mask])
    else:
        ret = torch.where(positive_mask, torch.sqrt(x), ret)
    return ret


def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert a unit quaternion to a standard form: one in which the real
    part is non negative.

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)


def quat_to_mat_scipy(quaternions: np.ndarray) -> np.ndarray:
    rotation = R.from_quat(quaternions)
    return rotation.as_matrix()


def mat_to_quat_scipy(matrix: np.ndarray) -> np.ndarray:
    rotation = R.from_matrix(matrix)
    return rotation.as_quat()


if __name__ == "__main__":

    num_tests = 10000  # Number of tests to run
    tolerance = 1e-6   # Tolerance for floating point comparison

    for _ in range(num_tests):
        # Generate random quaternions
        quaternions = torch.randn(1024, 4)
        quaternions = quaternions / torch.norm(quaternions, dim=-1, keepdim=True)  # Normalize to unit quaternions

        # Convert quaternion to matrix using PyTorch
        matrices_torch = quat_to_mat(quaternions)

        # Convert matrices back to quaternions using PyTorch
        quaternions_back = mat_to_quat(matrices_torch)

        # Standardize quaternions to handle the case where quaternions = -quaternions_back
        quaternions = standardize_quaternion(quaternions)
        quaternions_back = standardize_quaternion(quaternions_back)

        # Check if the original and converted quaternions match
        if not torch.allclose(quaternions, quaternions_back, atol=tolerance):
            print("Mismatch found!")
            print("Original quaternions:", quaternions)
            print("Converted quaternions:", quaternions_back)
            max_error = torch.max(torch.abs(quaternions - quaternions_back))
            print("Max error:", max_error)
        else:
            print("All tests passed successfully!")

    # quaternions = torch.randn(1024, 4) * 20
    # # quaternions = quaternions / torch.norm(quaternions, dim=-1, keepdim=True)  # Normalize to unit quaternions

    # # Convert quaternion to matrix using PyTorch
    # matrices_torch = quat_to_mat(quaternions).numpy()

    # # Convert quaternion to matrix using SciPy
    # matrices_scipy = quat_to_mat_scipy(quaternions.numpy())

    # # Convert matrices back to quaternions using PyTorch
    # quaternions_torch = mat_to_quat(torch.from_numpy(matrices_scipy)).numpy()

    # # Convert matrices back to quaternions using SciPy
    # quaternions_scipy = mat_to_quat_scipy(matrices_torch)

    # reconvert_mat_diff = quat_to_mat_scipy(quaternions_torch) - quat_to_mat_scipy(quaternions_scipy)
    # # Compare results
    # print("Matrix conversion difference:", np.linalg.norm(matrices_torch - matrices_scipy))
    # print("Quaternion conversion difference:", np.linalg.norm(reconvert_mat_diff))
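
A small consistency sketch (not part of the file): the scalar-last XYZW convention used above matches scipy.spatial.transform.Rotation, assuming the helpers are importable from vggt.utils.rotation:

import numpy as np
import torch
from vggt.utils.rotation import quat_to_mat, mat_to_quat, quat_to_mat_scipy, standardize_quaternion

q = torch.tensor([[0.0, 0.0, 0.7071068, 0.7071068]])   # 90 deg about z, ijkr (scalar-last)

R_torch = quat_to_mat(q).numpy()
R_scipy = quat_to_mat_scipy(q.numpy())
assert np.allclose(R_torch, R_scipy, atol=1e-5)          # same convention as SciPy

q_back = standardize_quaternion(mat_to_quat(quat_to_mat(q)))
assert torch.allclose(q_back, q, atol=1e-5)               # round trip recovers the quaternion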
viser_fn.py
ADDED
@@ -0,0 +1,284 @@
"""Visualization utilities for 3D reconstruction results using Viser.

Provides tools to visualize predicted camera poses, 3D point clouds, and confidence
thresholding through an interactive web interface.
"""

import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import tyro
from tqdm.auto import tqdm
import cv2
import viser
import viser.transforms as tf
import glob
import os
from scipy.spatial.transform import Rotation as R
# from camera import closed_form_inverse_se3
import torch
import threading


def viser_wrapper(
    pred_dict: dict,
    port: int = None,
    init_conf_threshold: float = 3.0,
) -> None:
    """Visualize predictions with an interactive viser server.

    Args:
        pred_dict: Dictionary containing predictions.
        port: Optional port number for the viser server. If None, a random port will be used.
        init_conf_threshold: Initial confidence threshold for the point-cloud filter.
    """
    print(f"Starting viser server on port {port}")  # Debug print

    server = viser.ViserServer(host="0.0.0.0", port=port)
    # server = viser.ViserServer(port=port)
    server.gui.configure_theme(titlebar_content=None, control_layout="collapsible")

    # Unpack and preprocess inputs
    images = pred_dict["images"]
    world_points = pred_dict["pred_world_points"]
    conf = pred_dict["pred_world_points_conf"]
    extrinsics = pred_dict["last_pred_extrinsic"]

    # Handle batch dimension if present
    if len(images.shape) > 4:
        images = images[0]
        world_points = world_points[0]
        conf = conf[0]
        extrinsics = extrinsics[0]

    colors = images.transpose(0, 2, 3, 1)  # Convert to (B, H, W, C)

    # Reshape for visualization
    S, H, W, _ = world_points.shape
    colors = (colors.reshape(-1, 3) * 255).astype(np.uint8)  # Convert to 0-255 range
    conf = conf.reshape(-1)
    world_points = world_points.reshape(-1, 3)

    # Calculate camera poses in world coordinates
    cam_to_world = closed_form_inverse_se3(extrinsics)
    extrinsics = cam_to_world[:, :3, :]

    # Center scene for better visualization
    scene_center = np.mean(world_points, axis=0)
    world_points -= scene_center
    extrinsics[..., -1] -= scene_center

    # set points3d as world_points
    points = world_points

    # frame_mask
    frame_indices = np.arange(S)
    frame_indices = frame_indices[:, None, None]        # Shape: (S, 1, 1)
    frame_indices = np.tile(frame_indices, (1, H, W))   # Shape: (S, H, W)
    frame_indices = frame_indices.reshape(-1)

    ############################################################
    ############################################################

    gui_points_conf = server.gui.add_slider(
        "Confidence Thres",
        min=0.1,
        max=20,
        step=0.05,
        initial_value=init_conf_threshold,
    )

    gui_point_size = server.gui.add_slider(
        "Point size", min=0.00001, max=0.01, step=0.0001, initial_value=0.00001
    )

    # Change from "Frame Selector" to more descriptive name
    gui_frame_selector = server.gui.add_dropdown(
        "Filter by Frame",  # More action-oriented name
        options=["All"] + [str(i) for i in range(S)],
        initial_value="All",
    )

    # Initial mask shows all points passing confidence threshold
    init_conf_mask = conf > init_conf_threshold
    point_cloud = server.scene.add_point_cloud(
        name="viser_pcd",
        points=points[init_conf_mask],
        colors=colors[init_conf_mask],
        point_size=gui_point_size.value,
        point_shape="circle",
    )

    frames: List[viser.FrameHandle] = []

    def visualize_frames(extrinsics: np.ndarray, intrinsics: np.ndarray, images: np.ndarray) -> None:
        """Send all COLMAP elements to viser for visualization. This could be optimized
        a ton!"""
        extrinsics = np.copy(extrinsics)
        # Remove existing image frames.
        for frame in frames:
            frame.remove()
        frames.clear()

        def attach_callback(
            frustum: viser.CameraFrustumHandle, frame: viser.FrameHandle
        ) -> None:
            @frustum.on_click
            def _(_) -> None:
                for client in server.get_clients().values():
                    client.camera.wxyz = frame.wxyz
                    client.camera.position = frame.position

        img_ids = sorted(range(S))
        for img_id in tqdm(img_ids):
            cam_to_world = extrinsics[img_id]

            T_world_camera = tf.SE3.from_matrix(cam_to_world)

            ratio = 1
            frame = server.scene.add_frame(
                f"frame_{img_id}",
                wxyz=T_world_camera.rotation().wxyz,
                position=T_world_camera.translation(),
                axes_length=0.05 / ratio,
                axes_radius=0.002 / ratio,
                origin_radius=0.002 / ratio,
            )

            frames.append(frame)

            img = images[img_id]
            img = (img.transpose(1, 2, 0) * 255).astype(np.uint8)
            # import pdb;pdb.set_trace()
            H, W = img.shape[:2]
            # fy = intrinsics[img_id, 1, 1] * H
            fy = 1.1 * H
            image = img
            # image = image[::downsample_factor, ::downsample_factor]
            frustum = server.scene.add_camera_frustum(
                f"frame_{img_id}/frustum",
                fov=2 * np.arctan2(H / 2, fy),
                aspect=W / H,
                scale=0.05 / ratio,
                image=image,
                line_width=1.0,
                # line_thickness=0.01,
            )

            attach_callback(frustum, frame)

    @gui_points_conf.on_update
    def _(_) -> None:
        conf_mask = conf > gui_points_conf.value
        frame_mask = np.ones_like(conf_mask)  # Default to all frames
        if gui_frame_selector.value != "All":
            selected_idx = int(gui_frame_selector.value)
            frame_mask = (frame_indices == selected_idx)

        combined_mask = conf_mask & frame_mask
        point_cloud.points = points[combined_mask]
        point_cloud.colors = colors[combined_mask]

    @gui_point_size.on_update
    def _(_) -> None:
        point_cloud.point_size = gui_point_size.value

    @gui_frame_selector.on_update
    def _(_) -> None:
        """Update points based on frame selection."""
        conf_mask = conf > gui_points_conf.value

        if gui_frame_selector.value == "All":
            # Show all points passing confidence threshold
            point_cloud.points = points[conf_mask]
            point_cloud.colors = colors[conf_mask]
        else:
            # Show only selected frame's points
            selected_idx = int(gui_frame_selector.value)
            frame_mask = (frame_indices == selected_idx)
            combined_mask = conf_mask & frame_mask
            point_cloud.points = points[combined_mask]
            point_cloud.colors = colors[combined_mask]

            # Move camera to selected frame
            # if 0 <= selected_idx < len(frames):
            #     selected_frame = frames[selected_idx]
            #     for client in server.get_clients().values():
            #         client.camera.wxyz = selected_frame.wxyz
            #         client.camera.position = selected_frame.position

    # Initial visualization
    visualize_frames(extrinsics, None, images)

    # Start server update loop in a background thread
    def server_loop():
        while True:
            time.sleep(1e-3)  # Small sleep to prevent CPU hogging

    thread = threading.Thread(target=server_loop, daemon=True)
    thread.start()


def closed_form_inverse_se3(se3, R=None, T=None):
    """
    Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.

    If `R` and `T` are provided, they must correspond to the rotation and translation
    components of `se3`. Otherwise, they will be extracted from `se3`.

    Args:
        se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
        R (optional): Nx3x3 array or tensor of rotation matrices.
        T (optional): Nx3x1 array or tensor of translation vectors.

    Returns:
        Inverted SE3 matrices with the same type and device as `se3`.

    Shapes:
        se3: (N, 4, 4)
        R: (N, 3, 3)
        T: (N, 3, 1)
    """
    # Check if se3 is a numpy array or a torch tensor
    is_numpy = isinstance(se3, np.ndarray)

    # Validate shapes
    if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
        raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")

    # Extract R and T if not provided
    if R is None:
        R = se3[:, :3, :3]  # (N,3,3)
    if T is None:
        T = se3[:, :3, 3:]  # (N,3,1)

    # Transpose R
    if is_numpy:
        # Compute the transpose of the rotation for NumPy
        R_transposed = np.transpose(R, (0, 2, 1))
        # -R^T t for NumPy
        top_right = -np.matmul(R_transposed, T)
        inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
    else:
        R_transposed = R.transpose(1, 2)  # (N,3,3)
        top_right = -torch.bmm(R_transposed, T)  # (N,3,1)
        inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
        inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)

    inverted_matrix[:, :3, :3] = R_transposed
    inverted_matrix[:, :3, 3:] = top_right

    return inverted_matrix
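
A quick numerical sketch (not part of the file) checking that closed_form_inverse_se3 really inverts a batch of SE(3) matrices, assuming it is imported from viser_fn:

import numpy as np
from viser_fn import closed_form_inverse_se3

rng = np.random.default_rng(0)
Q, _ = np.linalg.qr(rng.normal(size=(3, 3)))      # an orthonormal 3x3 basis (Q^T Q = I)

se3 = np.tile(np.eye(4), (2, 1, 1))               # batch of two 4x4 poses
se3[:, :3, :3] = Q
se3[:, :3, 3] = rng.normal(size=3)

inv = closed_form_inverse_se3(se3)
assert np.allclose(inv @ se3, np.eye(4), atol=1e-8)   # inverse composes to identity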