Spaces:
Runtime error
Runtime error
tokenid
commited on
Commit
•
ad06aed
1
Parent(s):
6af576b
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +1 -0
- README.md +1 -1
- app.py +349 -0
- configs/instant-mesh-base.yaml +22 -0
- configs/instant-mesh-large.yaml +22 -0
- configs/instant-nerf-base.yaml +21 -0
- configs/instant-nerf-large.yaml +21 -0
- examples/bird.jpg +0 -0
- examples/bubble_mart_blue.png +0 -0
- examples/cake.jpg +0 -0
- examples/cartoon_dinosaur.png +0 -0
- examples/cartoon_girl.jpg +0 -0
- examples/chair_comfort.jpg +0 -0
- examples/chair_wood.jpg +0 -0
- examples/chest.jpg +0 -0
- examples/cube.png +0 -0
- examples/extinguisher.png +0 -0
- examples/fruit_bycycle.jpg +0 -0
- examples/fruit_elephant.jpg +0 -0
- examples/genshin_building.png +0 -0
- examples/house2.jpg +0 -0
- examples/kunkun.png +0 -0
- examples/mushroom_teapot.jpg +0 -0
- examples/pikachu.png +0 -0
- examples/pistol.png +0 -0
- examples/plant.jpg +0 -0
- examples/robot.jpg +0 -0
- examples/sea_turtle.png +0 -0
- examples/skating_shoe.jpg +0 -0
- examples/sorting_board.png +0 -0
- examples/sword.png +0 -0
- examples/toy_car.jpg +0 -0
- examples/toyduck.png +0 -0
- examples/watermelon.png +0 -0
- examples/whitedog.png +0 -0
- examples/x_cube.jpg +0 -0
- examples/x_teapot.jpg +0 -0
- examples/x_toyduck.jpg +0 -0
- requirements.txt +21 -0
- src/__init__.py +0 -0
- src/data/__init__.py +0 -0
- src/data/objaverse.py +329 -0
- src/model.py +310 -0
- src/model_mesh.py +325 -0
- src/models/__init__.py +0 -0
- src/models/decoder/__init__.py +0 -0
- src/models/decoder/transformer.py +123 -0
- src/models/encoder/__init__.py +0 -0
- src/models/encoder/dino.py +550 -0
- src/models/encoder/dino_wrapper.py +80 -0
.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__pycache__
|
README.md
CHANGED
@@ -7,7 +7,7 @@ sdk: gradio
|
|
7 |
sdk_version: 4.25.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
-
license:
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
7 |
sdk_version: 4.25.0
|
8 |
app_file: app.py
|
9 |
pinned: false
|
10 |
+
license: apache-2.0
|
11 |
---
|
12 |
|
13 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
app.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import imageio
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import rembg
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision.transforms import v2
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
from omegaconf import OmegaConf
|
10 |
+
from einops import rearrange, repeat
|
11 |
+
from tqdm import tqdm
|
12 |
+
from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler
|
13 |
+
|
14 |
+
from src.utils.train_util import instantiate_from_config
|
15 |
+
from src.utils.camera_util import (
|
16 |
+
FOV_to_intrinsics,
|
17 |
+
get_zero123plus_input_cameras,
|
18 |
+
get_circular_camera_poses,
|
19 |
+
)
|
20 |
+
from src.utils.mesh_util import save_obj
|
21 |
+
from src.utils.infer_util import remove_background, resize_foreground, images_to_video
|
22 |
+
|
23 |
+
import tempfile
|
24 |
+
from functools import partial
|
25 |
+
|
26 |
+
from huggingface_hub import hf_hub_download
|
27 |
+
import spaces
|
28 |
+
|
29 |
+
|
30 |
+
def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False):
|
31 |
+
"""
|
32 |
+
Get the rendering camera parameters.
|
33 |
+
"""
|
34 |
+
c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation)
|
35 |
+
if is_flexicubes:
|
36 |
+
cameras = torch.linalg.inv(c2ws)
|
37 |
+
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1)
|
38 |
+
else:
|
39 |
+
extrinsics = c2ws.flatten(-2)
|
40 |
+
intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2)
|
41 |
+
cameras = torch.cat([extrinsics, intrinsics], dim=-1)
|
42 |
+
cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1)
|
43 |
+
return cameras
|
44 |
+
|
45 |
+
|
46 |
+
def images_to_video(images, output_path, fps=30):
|
47 |
+
# images: (N, C, H, W)
|
48 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
49 |
+
frames = []
|
50 |
+
for i in range(images.shape[0]):
|
51 |
+
frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8).clip(0, 255)
|
52 |
+
assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \
|
53 |
+
f"Frame shape mismatch: {frame.shape} vs {images.shape}"
|
54 |
+
assert frame.min() >= 0 and frame.max() <= 255, \
|
55 |
+
f"Frame value out of range: {frame.min()} ~ {frame.max()}"
|
56 |
+
frames.append(frame)
|
57 |
+
imageio.mimwrite(output_path, np.stack(frames), fps=fps, codec='h264')
|
58 |
+
|
59 |
+
|
60 |
+
###############################################################################
|
61 |
+
# Configuration.
|
62 |
+
###############################################################################
|
63 |
+
|
64 |
+
config_path = 'configs/instant-mesh-large-eval.yaml'
|
65 |
+
config = OmegaConf.load(config_path)
|
66 |
+
config_name = os.path.basename(config_path).replace('.yaml', '')
|
67 |
+
model_config = config.model_config
|
68 |
+
infer_config = config.infer_config
|
69 |
+
|
70 |
+
IS_FLEXICUBES = True if config_name.startswith('instant-mesh') else False
|
71 |
+
|
72 |
+
device = torch.device('cuda')
|
73 |
+
|
74 |
+
# load diffusion model
|
75 |
+
print('Loading diffusion model ...')
|
76 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
77 |
+
"sudo-ai/zero123plus-v1.2",
|
78 |
+
custom_pipeline="zero123plus",
|
79 |
+
torch_dtype=torch.float16,
|
80 |
+
)
|
81 |
+
pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(
|
82 |
+
pipeline.scheduler.config, timestep_spacing='trailing'
|
83 |
+
)
|
84 |
+
|
85 |
+
# load custom white-background UNet
|
86 |
+
unet_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="diffusion_pytorch_model.bin", repo_type="model")
|
87 |
+
state_dict = torch.load(unet_ckpt_path, map_location='cpu')
|
88 |
+
pipeline.unet.load_state_dict(state_dict, strict=True)
|
89 |
+
|
90 |
+
pipeline = pipeline.to(device)
|
91 |
+
|
92 |
+
# load reconstruction model
|
93 |
+
print('Loading reconstruction model ...')
|
94 |
+
model_ckpt_path = hf_hub_download(repo_id="TencentARC/InstantMesh", filename="instant_mesh_large.ckpt", repo_type="model")
|
95 |
+
model = instantiate_from_config(model_config)
|
96 |
+
state_dict = torch.load(model_ckpt_path, map_location='cpu')['state_dict']
|
97 |
+
state_dict = {k[14:]: v for k, v in state_dict.items() if k.startswith('lrm_generator.') and 'source_camera' not in k}
|
98 |
+
model.load_state_dict(state_dict, strict=True)
|
99 |
+
|
100 |
+
model = model.to(device)
|
101 |
+
if IS_FLEXICUBES:
|
102 |
+
model.init_flexicubes_geometry(device)
|
103 |
+
model = model.eval()
|
104 |
+
|
105 |
+
print('Loading Finished!')
|
106 |
+
|
107 |
+
|
108 |
+
def check_input_image(input_image):
|
109 |
+
if input_image is None:
|
110 |
+
raise gr.Error("No image uploaded!")
|
111 |
+
|
112 |
+
|
113 |
+
def preprocess(input_image, do_remove_background):
|
114 |
+
|
115 |
+
rembg_session = rembg.new_session() if do_remove_background else None
|
116 |
+
|
117 |
+
if do_remove_background:
|
118 |
+
input_image = remove_background(input_image, rembg_session)
|
119 |
+
input_image = resize_foreground(input_image, 0.85)
|
120 |
+
|
121 |
+
return input_image
|
122 |
+
|
123 |
+
|
124 |
+
@spaces.GPU
|
125 |
+
def generate_mvs(input_image, sample_steps, sample_seed):
|
126 |
+
|
127 |
+
seed_everything(sample_seed)
|
128 |
+
|
129 |
+
# sampling
|
130 |
+
z123_image = pipeline(
|
131 |
+
input_image,
|
132 |
+
num_inference_steps=sample_steps
|
133 |
+
).images[0]
|
134 |
+
|
135 |
+
show_image = np.asarray(z123_image, dtype=np.uint8)
|
136 |
+
show_image = torch.from_numpy(show_image) # (960, 640, 3)
|
137 |
+
show_image = rearrange(show_image, '(n h) (m w) c -> (m h) (n w) c', n=3, m=2)
|
138 |
+
show_image = Image.fromarray(show_image.numpy())
|
139 |
+
|
140 |
+
return z123_image, show_image
|
141 |
+
|
142 |
+
|
143 |
+
@spaces.GPU
|
144 |
+
def make_mesh(mesh_fpath, planes):
|
145 |
+
|
146 |
+
mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
|
147 |
+
mesh_dirname = os.path.dirname(mesh_fpath)
|
148 |
+
mesh_vis_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
|
149 |
+
|
150 |
+
with torch.no_grad():
|
151 |
+
|
152 |
+
# get mesh
|
153 |
+
mesh_out = model.extract_mesh(
|
154 |
+
planes,
|
155 |
+
use_texture_map=False,
|
156 |
+
**infer_config,
|
157 |
+
)
|
158 |
+
|
159 |
+
vertices, faces, vertex_colors = mesh_out
|
160 |
+
vertices = vertices[:, [0, 2, 1]]
|
161 |
+
vertices[:, -1] *= -1
|
162 |
+
|
163 |
+
save_obj(vertices, faces, vertex_colors, mesh_fpath)
|
164 |
+
|
165 |
+
print(f"Mesh saved to {mesh_fpath}")
|
166 |
+
|
167 |
+
return mesh_fpath
|
168 |
+
|
169 |
+
|
170 |
+
@spaces.GPU
|
171 |
+
def make3d(images):
|
172 |
+
|
173 |
+
images = np.asarray(images, dtype=np.float32) / 255.0
|
174 |
+
images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640)
|
175 |
+
images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320)
|
176 |
+
|
177 |
+
input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=2.5).to(device)
|
178 |
+
render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device)
|
179 |
+
|
180 |
+
images = images.unsqueeze(0).to(device)
|
181 |
+
images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
|
182 |
+
|
183 |
+
mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
|
184 |
+
print(mesh_fpath)
|
185 |
+
mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
|
186 |
+
mesh_dirname = os.path.dirname(mesh_fpath)
|
187 |
+
video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4")
|
188 |
+
|
189 |
+
with torch.no_grad():
|
190 |
+
# get triplane
|
191 |
+
planes = model.forward_planes(images, input_cameras)
|
192 |
+
|
193 |
+
# get video
|
194 |
+
chunk_size = 20 if IS_FLEXICUBES else 1
|
195 |
+
render_size = 384
|
196 |
+
|
197 |
+
frames = []
|
198 |
+
for i in tqdm(range(0, render_cameras.shape[1], chunk_size)):
|
199 |
+
if IS_FLEXICUBES:
|
200 |
+
frame = model.forward_geometry(
|
201 |
+
planes,
|
202 |
+
render_cameras[:, i:i+chunk_size],
|
203 |
+
render_size=render_size,
|
204 |
+
)['img']
|
205 |
+
else:
|
206 |
+
frame = model.synthesizer(
|
207 |
+
planes,
|
208 |
+
cameras=render_cameras[:, i:i+chunk_size],
|
209 |
+
render_size=render_size,
|
210 |
+
)['images_rgb']
|
211 |
+
frames.append(frame)
|
212 |
+
frames = torch.cat(frames, dim=1)
|
213 |
+
|
214 |
+
images_to_video(
|
215 |
+
frames[0],
|
216 |
+
video_fpath,
|
217 |
+
fps=30,
|
218 |
+
)
|
219 |
+
|
220 |
+
print(f"Video saved to {video_fpath}")
|
221 |
+
|
222 |
+
mesh_fpath = make_mesh(mesh_fpath, planes)
|
223 |
+
|
224 |
+
return video_fpath, mesh_fpath
|
225 |
+
|
226 |
+
|
227 |
+
import gradio as gr
|
228 |
+
|
229 |
+
_HEADER_ = '''
|
230 |
+
<h2><b>Official 🤗 Gradio demo for</b>
|
231 |
+
<a href='https://github.com/TencentARC/InstantMesh' target='_blank'>
|
232 |
+
<b>InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models</b>
|
233 |
+
</a>.
|
234 |
+
</h2>
|
235 |
+
'''
|
236 |
+
|
237 |
+
_LINKS_ = '''
|
238 |
+
<h3>Code is available at <a href='https://github.com/TencentARC/InstantMesh' target='_blank'>GitHub</a></h3>
|
239 |
+
<h3>Report is available at <a href='https://arxiv.org/abs/2404.07191' target='_blank'>ArXiv</a></h3>
|
240 |
+
'''
|
241 |
+
|
242 |
+
_CITE_ = r"""
|
243 |
+
```bibtex
|
244 |
+
@article{xu2024instantmesh,
|
245 |
+
title={InstantMesh: Efficient 3D Mesh Generation from a Single Image with Sparse-view Large Reconstruction Models},
|
246 |
+
author={Xu, Jiale and Cheng, Weihao and Gao, Yiming and Wang, Xintao and Gao, Shenghua and Shan, Ying},
|
247 |
+
journal={arXiv preprint arXiv:2404.07191},
|
248 |
+
year={2024}
|
249 |
+
}
|
250 |
+
```
|
251 |
+
"""
|
252 |
+
|
253 |
+
|
254 |
+
with gr.Blocks() as demo:
|
255 |
+
gr.Markdown(_HEADER_)
|
256 |
+
with gr.Row(variant="panel"):
|
257 |
+
with gr.Column():
|
258 |
+
with gr.Row():
|
259 |
+
input_image = gr.Image(
|
260 |
+
label="Input Image",
|
261 |
+
image_mode="RGBA",
|
262 |
+
sources="upload",
|
263 |
+
width=256,
|
264 |
+
height=256,
|
265 |
+
type="pil",
|
266 |
+
elem_id="content_image",
|
267 |
+
)
|
268 |
+
processed_image = gr.Image(
|
269 |
+
label="Processed Image",
|
270 |
+
image_mode="RGBA",
|
271 |
+
width=256,
|
272 |
+
height=256,
|
273 |
+
type="pil",
|
274 |
+
interactive=False
|
275 |
+
)
|
276 |
+
with gr.Row():
|
277 |
+
with gr.Group():
|
278 |
+
do_remove_background = gr.Checkbox(
|
279 |
+
label="Remove Background", value=True
|
280 |
+
)
|
281 |
+
sample_seed = gr.Number(value=42, label="Seed (Try a different value if the result is unsatisfying)", precision=0)
|
282 |
+
|
283 |
+
sample_steps = gr.Slider(
|
284 |
+
label="Sample Steps",
|
285 |
+
minimum=30,
|
286 |
+
maximum=75,
|
287 |
+
value=75,
|
288 |
+
step=5
|
289 |
+
)
|
290 |
+
|
291 |
+
with gr.Row():
|
292 |
+
submit = gr.Button("Generate", elem_id="generate", variant="primary")
|
293 |
+
|
294 |
+
with gr.Row(variant="panel"):
|
295 |
+
gr.Examples(
|
296 |
+
examples=[
|
297 |
+
os.path.join("examples", img_name) for img_name in sorted(os.listdir("examples"))
|
298 |
+
],
|
299 |
+
inputs=[input_image],
|
300 |
+
label="Examples",
|
301 |
+
examples_per_page=20
|
302 |
+
)
|
303 |
+
|
304 |
+
with gr.Column():
|
305 |
+
|
306 |
+
with gr.Row():
|
307 |
+
|
308 |
+
with gr.Column():
|
309 |
+
mv_show_images = gr.Image(
|
310 |
+
label="Generated Multi-views",
|
311 |
+
type="pil",
|
312 |
+
width=379,
|
313 |
+
interactive=False
|
314 |
+
)
|
315 |
+
|
316 |
+
with gr.Column():
|
317 |
+
output_video = gr.Video(
|
318 |
+
label="video", format="mp4",
|
319 |
+
width=379,
|
320 |
+
autoplay=True,
|
321 |
+
interactive=False
|
322 |
+
)
|
323 |
+
|
324 |
+
with gr.Row():
|
325 |
+
output_model_obj = gr.Model3D(
|
326 |
+
label="Output Model (OBJ Format)",
|
327 |
+
width=768,
|
328 |
+
interactive=False,
|
329 |
+
)
|
330 |
+
gr.Markdown(_LINKS_)
|
331 |
+
gr.Markdown(_CITE_)
|
332 |
+
|
333 |
+
mv_images = gr.State()
|
334 |
+
|
335 |
+
submit.click(fn=check_input_image, inputs=[input_image]).success(
|
336 |
+
fn=preprocess,
|
337 |
+
inputs=[input_image, do_remove_background],
|
338 |
+
outputs=[processed_image],
|
339 |
+
).success(
|
340 |
+
fn=generate_mvs,
|
341 |
+
inputs=[processed_image, sample_steps, sample_seed],
|
342 |
+
outputs=[mv_images, mv_show_images],
|
343 |
+
).success(
|
344 |
+
fn=make3d,
|
345 |
+
inputs=[mv_images],
|
346 |
+
outputs=[output_video, output_model_obj]
|
347 |
+
)
|
348 |
+
|
349 |
+
demo.launch()
|
configs/instant-mesh-base.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm_mesh.InstantMesh
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 12
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 40
|
13 |
+
rendering_samples_per_ray: 96
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/instant_mesh_base.ckpt
|
21 |
+
texture_resolution: 1024
|
22 |
+
render_resolution: 512
|
configs/instant-mesh-large.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm_mesh.InstantMesh
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
grid_res: 128
|
15 |
+
grid_scale: 2.1
|
16 |
+
|
17 |
+
|
18 |
+
infer_config:
|
19 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
20 |
+
model_path: ckpts/instant_mesh_large.ckpt
|
21 |
+
texture_resolution: 1024
|
22 |
+
render_resolution: 512
|
configs/instant-nerf-base.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm.InstantNeRF
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 12
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 40
|
13 |
+
rendering_samples_per_ray: 96
|
14 |
+
|
15 |
+
|
16 |
+
infer_config:
|
17 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
+
model_path: ckpts/instant_nerf_base.ckpt
|
19 |
+
mesh_threshold: 10.0
|
20 |
+
mesh_resolution: 256
|
21 |
+
render_resolution: 384
|
configs/instant-nerf-large.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_config:
|
2 |
+
target: src.models.lrm.InstantNeRF
|
3 |
+
params:
|
4 |
+
encoder_feat_dim: 768
|
5 |
+
encoder_freeze: false
|
6 |
+
encoder_model_name: facebook/dino-vitb16
|
7 |
+
transformer_dim: 1024
|
8 |
+
transformer_layers: 16
|
9 |
+
transformer_heads: 16
|
10 |
+
triplane_low_res: 32
|
11 |
+
triplane_high_res: 64
|
12 |
+
triplane_dim: 80
|
13 |
+
rendering_samples_per_ray: 128
|
14 |
+
|
15 |
+
|
16 |
+
infer_config:
|
17 |
+
unet_path: ckpts/diffusion_pytorch_model.bin
|
18 |
+
model_path: ckpts/instant_nerf_large.ckpt
|
19 |
+
mesh_threshold: 10.0
|
20 |
+
mesh_resolution: 256
|
21 |
+
render_resolution: 384
|
examples/bird.jpg
ADDED
examples/bubble_mart_blue.png
ADDED
examples/cake.jpg
ADDED
examples/cartoon_dinosaur.png
ADDED
examples/cartoon_girl.jpg
ADDED
examples/chair_comfort.jpg
ADDED
examples/chair_wood.jpg
ADDED
examples/chest.jpg
ADDED
examples/cube.png
ADDED
examples/extinguisher.png
ADDED
examples/fruit_bycycle.jpg
ADDED
examples/fruit_elephant.jpg
ADDED
examples/genshin_building.png
ADDED
examples/house2.jpg
ADDED
examples/kunkun.png
ADDED
examples/mushroom_teapot.jpg
ADDED
examples/pikachu.png
ADDED
examples/pistol.png
ADDED
examples/plant.jpg
ADDED
examples/robot.jpg
ADDED
examples/sea_turtle.png
ADDED
examples/skating_shoe.jpg
ADDED
examples/sorting_board.png
ADDED
examples/sword.png
ADDED
examples/toy_car.jpg
ADDED
examples/toyduck.png
ADDED
examples/watermelon.png
ADDED
examples/whitedog.png
ADDED
examples/x_cube.jpg
ADDED
examples/x_teapot.jpg
ADDED
examples/x_toyduck.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pytorch-lightning==2.1.2
|
2 |
+
einops
|
3 |
+
omegaconf
|
4 |
+
deepspeed
|
5 |
+
torchmetrics
|
6 |
+
webdataset
|
7 |
+
accelerate
|
8 |
+
tensorboard
|
9 |
+
PyMCubes
|
10 |
+
trimesh
|
11 |
+
rembg
|
12 |
+
transformers==4.34.1
|
13 |
+
diffusers==0.19.3
|
14 |
+
bitsandbytes
|
15 |
+
imageio[ffmpeg]
|
16 |
+
xatlas
|
17 |
+
plyfile
|
18 |
+
xformers==0.0.22.post7
|
19 |
+
git+https://github.com/NVlabs/nvdiffrast/
|
20 |
+
torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
|
21 |
+
huggingface-hub
|
src/__init__.py
ADDED
File without changes
|
src/data/__init__.py
ADDED
File without changes
|
src/data/objaverse.py
ADDED
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys
|
2 |
+
import math
|
3 |
+
import json
|
4 |
+
import importlib
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import random
|
9 |
+
import numpy as np
|
10 |
+
from PIL import Image
|
11 |
+
import webdataset as wds
|
12 |
+
import pytorch_lightning as pl
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch.utils.data import Dataset
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from torch.utils.data.distributed import DistributedSampler
|
19 |
+
from torchvision import transforms
|
20 |
+
|
21 |
+
from src.utils.train_util import instantiate_from_config
|
22 |
+
from src.utils.camera_util import (
|
23 |
+
FOV_to_intrinsics,
|
24 |
+
center_looking_at_camera_pose,
|
25 |
+
get_surrounding_views,
|
26 |
+
)
|
27 |
+
|
28 |
+
|
29 |
+
class DataModuleFromConfig(pl.LightningDataModule):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
batch_size=8,
|
33 |
+
num_workers=4,
|
34 |
+
train=None,
|
35 |
+
validation=None,
|
36 |
+
test=None,
|
37 |
+
**kwargs,
|
38 |
+
):
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.batch_size = batch_size
|
42 |
+
self.num_workers = num_workers
|
43 |
+
|
44 |
+
self.dataset_configs = dict()
|
45 |
+
if train is not None:
|
46 |
+
self.dataset_configs['train'] = train
|
47 |
+
if validation is not None:
|
48 |
+
self.dataset_configs['validation'] = validation
|
49 |
+
if test is not None:
|
50 |
+
self.dataset_configs['test'] = test
|
51 |
+
|
52 |
+
def setup(self, stage):
|
53 |
+
|
54 |
+
if stage in ['fit']:
|
55 |
+
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
|
56 |
+
else:
|
57 |
+
raise NotImplementedError
|
58 |
+
|
59 |
+
def train_dataloader(self):
|
60 |
+
|
61 |
+
sampler = DistributedSampler(self.datasets['train'])
|
62 |
+
return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
63 |
+
|
64 |
+
def val_dataloader(self):
|
65 |
+
|
66 |
+
sampler = DistributedSampler(self.datasets['validation'])
|
67 |
+
return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler)
|
68 |
+
|
69 |
+
def test_dataloader(self):
|
70 |
+
|
71 |
+
return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False)
|
72 |
+
|
73 |
+
|
74 |
+
class ObjaverseData(Dataset):
|
75 |
+
def __init__(self,
|
76 |
+
root_dir='objaverse/',
|
77 |
+
meta_fname='valid_paths.json',
|
78 |
+
input_image_dir='rendering_random_32views',
|
79 |
+
target_image_dir='rendering_random_32views',
|
80 |
+
input_view_num=6,
|
81 |
+
target_view_num=2,
|
82 |
+
total_view_n=32,
|
83 |
+
fov=50,
|
84 |
+
camera_rotation=True,
|
85 |
+
validation=False,
|
86 |
+
):
|
87 |
+
self.root_dir = Path(root_dir)
|
88 |
+
self.input_image_dir = input_image_dir
|
89 |
+
self.target_image_dir = target_image_dir
|
90 |
+
|
91 |
+
self.input_view_num = input_view_num
|
92 |
+
self.target_view_num = target_view_num
|
93 |
+
self.total_view_n = total_view_n
|
94 |
+
self.fov = fov
|
95 |
+
self.camera_rotation = camera_rotation
|
96 |
+
|
97 |
+
with open(os.path.join(root_dir, meta_fname)) as f:
|
98 |
+
filtered_dict = json.load(f)
|
99 |
+
paths = filtered_dict['good_objs']
|
100 |
+
self.paths = paths
|
101 |
+
|
102 |
+
self.depth_scale = 4.0
|
103 |
+
|
104 |
+
total_objects = len(self.paths)
|
105 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
106 |
+
|
107 |
+
def __len__(self):
|
108 |
+
return len(self.paths)
|
109 |
+
|
110 |
+
def load_im(self, path, color):
|
111 |
+
'''
|
112 |
+
replace background pixel with random color in rendering
|
113 |
+
'''
|
114 |
+
pil_img = Image.open(path)
|
115 |
+
|
116 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
117 |
+
alpha = image[:, :, 3:]
|
118 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
119 |
+
|
120 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
121 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
122 |
+
return image, alpha
|
123 |
+
|
124 |
+
def __getitem__(self, index):
|
125 |
+
# load data
|
126 |
+
while True:
|
127 |
+
input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index])
|
128 |
+
target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index])
|
129 |
+
|
130 |
+
indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False)
|
131 |
+
input_indices = indices[:self.input_view_num]
|
132 |
+
target_indices = indices[self.input_view_num:]
|
133 |
+
|
134 |
+
'''background color, default: white'''
|
135 |
+
bg_white = [1., 1., 1.]
|
136 |
+
bg_black = [0., 0., 0.]
|
137 |
+
|
138 |
+
image_list = []
|
139 |
+
alpha_list = []
|
140 |
+
depth_list = []
|
141 |
+
normal_list = []
|
142 |
+
pose_list = []
|
143 |
+
|
144 |
+
try:
|
145 |
+
input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses']
|
146 |
+
for idx in input_indices:
|
147 |
+
image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white)
|
148 |
+
normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black)
|
149 |
+
depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
150 |
+
depth = torch.from_numpy(depth).unsqueeze(0)
|
151 |
+
pose = input_cameras[idx]
|
152 |
+
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
153 |
+
|
154 |
+
image_list.append(image)
|
155 |
+
alpha_list.append(alpha)
|
156 |
+
depth_list.append(depth)
|
157 |
+
normal_list.append(normal)
|
158 |
+
pose_list.append(pose)
|
159 |
+
|
160 |
+
target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses']
|
161 |
+
for idx in target_indices:
|
162 |
+
image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white)
|
163 |
+
normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black)
|
164 |
+
depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale
|
165 |
+
depth = torch.from_numpy(depth).unsqueeze(0)
|
166 |
+
pose = target_cameras[idx]
|
167 |
+
pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0)
|
168 |
+
|
169 |
+
image_list.append(image)
|
170 |
+
alpha_list.append(alpha)
|
171 |
+
depth_list.append(depth)
|
172 |
+
normal_list.append(normal)
|
173 |
+
pose_list.append(pose)
|
174 |
+
|
175 |
+
except Exception as e:
|
176 |
+
print(e)
|
177 |
+
index = np.random.randint(0, len(self.paths))
|
178 |
+
continue
|
179 |
+
|
180 |
+
break
|
181 |
+
|
182 |
+
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
183 |
+
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
184 |
+
depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W)
|
185 |
+
normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W)
|
186 |
+
w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4)
|
187 |
+
c2ws = torch.linalg.inv(w2cs).float()
|
188 |
+
|
189 |
+
normals = normals * 2.0 - 1.0
|
190 |
+
normals = F.normalize(normals, dim=1)
|
191 |
+
normals = (normals + 1.0) / 2.0
|
192 |
+
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
193 |
+
|
194 |
+
# random rotation along z axis
|
195 |
+
if self.camera_rotation:
|
196 |
+
degree = np.random.uniform(0, math.pi * 2)
|
197 |
+
rot = torch.tensor([
|
198 |
+
[np.cos(degree), -np.sin(degree), 0, 0],
|
199 |
+
[np.sin(degree), np.cos(degree), 0, 0],
|
200 |
+
[0, 0, 1, 0],
|
201 |
+
[0, 0, 0, 1],
|
202 |
+
]).unsqueeze(0).float()
|
203 |
+
c2ws = torch.matmul(rot, c2ws)
|
204 |
+
|
205 |
+
# rotate normals
|
206 |
+
N, _, H, W = normals.shape
|
207 |
+
normals = normals * 2.0 - 1.0
|
208 |
+
normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W)
|
209 |
+
normals = F.normalize(normals, dim=1)
|
210 |
+
normals = (normals + 1.0) / 2.0
|
211 |
+
normals = torch.lerp(torch.zeros_like(normals), normals, alphas)
|
212 |
+
|
213 |
+
# random scaling
|
214 |
+
if np.random.rand() < 0.5:
|
215 |
+
scale = np.random.uniform(0.8, 1.0)
|
216 |
+
c2ws[:, :3, 3] *= scale
|
217 |
+
depths *= scale
|
218 |
+
|
219 |
+
# instrinsics of perspective cameras
|
220 |
+
K = FOV_to_intrinsics(self.fov)
|
221 |
+
Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float()
|
222 |
+
|
223 |
+
data = {
|
224 |
+
'input_images': images[:self.input_view_num], # (6, 3, H, W)
|
225 |
+
'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W)
|
226 |
+
'input_depths': depths[:self.input_view_num], # (6, 1, H, W)
|
227 |
+
'input_normals': normals[:self.input_view_num], # (6, 3, H, W)
|
228 |
+
'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4)
|
229 |
+
'input_Ks': Ks[:self.input_view_num], # (6, 3, 3)
|
230 |
+
|
231 |
+
# lrm generator input and supervision
|
232 |
+
'target_images': images[self.input_view_num:], # (V, 3, H, W)
|
233 |
+
'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W)
|
234 |
+
'target_depths': depths[self.input_view_num:], # (V, 1, H, W)
|
235 |
+
'target_normals': normals[self.input_view_num:], # (V, 3, H, W)
|
236 |
+
'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4)
|
237 |
+
'target_Ks': Ks[self.input_view_num:], # (V, 3, 3)
|
238 |
+
|
239 |
+
'depth_available': 1,
|
240 |
+
}
|
241 |
+
return data
|
242 |
+
|
243 |
+
|
244 |
+
class ValidationData(Dataset):
|
245 |
+
def __init__(self,
|
246 |
+
root_dir='objaverse/',
|
247 |
+
input_view_num=6,
|
248 |
+
input_image_size=256,
|
249 |
+
fov=50,
|
250 |
+
):
|
251 |
+
self.root_dir = Path(root_dir)
|
252 |
+
self.input_view_num = input_view_num
|
253 |
+
self.input_image_size = input_image_size
|
254 |
+
self.fov = fov
|
255 |
+
|
256 |
+
self.paths = sorted(os.listdir(self.root_dir))
|
257 |
+
print('============= length of dataset %d =============' % len(self.paths))
|
258 |
+
|
259 |
+
cam_distance = 2.5
|
260 |
+
azimuths = np.array([30, 90, 150, 210, 270, 330])
|
261 |
+
elevations = np.array([30, -20, 30, -20, 30, -20])
|
262 |
+
azimuths = np.deg2rad(azimuths)
|
263 |
+
elevations = np.deg2rad(elevations)
|
264 |
+
|
265 |
+
x = cam_distance * np.cos(elevations) * np.cos(azimuths)
|
266 |
+
y = cam_distance * np.cos(elevations) * np.sin(azimuths)
|
267 |
+
z = cam_distance * np.sin(elevations)
|
268 |
+
|
269 |
+
cam_locations = np.stack([x, y, z], axis=-1)
|
270 |
+
cam_locations = torch.from_numpy(cam_locations).float()
|
271 |
+
c2ws = center_looking_at_camera_pose(cam_locations)
|
272 |
+
self.c2ws = c2ws.float()
|
273 |
+
self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float()
|
274 |
+
|
275 |
+
render_c2ws = get_surrounding_views(M=8, radius=cam_distance)
|
276 |
+
render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1)
|
277 |
+
self.render_c2ws = render_c2ws.float()
|
278 |
+
self.render_Ks = render_Ks.float()
|
279 |
+
|
280 |
+
def __len__(self):
|
281 |
+
return len(self.paths)
|
282 |
+
|
283 |
+
def load_im(self, path, color):
|
284 |
+
'''
|
285 |
+
replace background pixel with random color in rendering
|
286 |
+
'''
|
287 |
+
pil_img = Image.open(path)
|
288 |
+
pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC)
|
289 |
+
|
290 |
+
image = np.asarray(pil_img, dtype=np.float32) / 255.
|
291 |
+
if image.shape[-1] == 4:
|
292 |
+
alpha = image[:, :, 3:]
|
293 |
+
image = image[:, :, :3] * alpha + color * (1 - alpha)
|
294 |
+
else:
|
295 |
+
alpha = np.ones_like(image[:, :, :1])
|
296 |
+
|
297 |
+
image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
|
298 |
+
alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float()
|
299 |
+
return image, alpha
|
300 |
+
|
301 |
+
def __getitem__(self, index):
|
302 |
+
# load data
|
303 |
+
input_image_path = os.path.join(self.root_dir, self.paths[index])
|
304 |
+
|
305 |
+
'''background color, default: white'''
|
306 |
+
# color = np.random.uniform(0.48, 0.52)
|
307 |
+
bkg_color = [1.0, 1.0, 1.0]
|
308 |
+
|
309 |
+
image_list = []
|
310 |
+
alpha_list = []
|
311 |
+
|
312 |
+
for idx in range(self.input_view_num):
|
313 |
+
image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color)
|
314 |
+
image_list.append(image)
|
315 |
+
alpha_list.append(alpha)
|
316 |
+
|
317 |
+
images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W)
|
318 |
+
alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W)
|
319 |
+
|
320 |
+
data = {
|
321 |
+
'input_images': images, # (6, 3, H, W)
|
322 |
+
'input_alphas': alphas, # (6, 1, H, W)
|
323 |
+
'input_c2ws': self.c2ws, # (6, 4, 4)
|
324 |
+
'input_Ks': self.Ks, # (6, 3, 3)
|
325 |
+
|
326 |
+
'render_c2ws': self.render_c2ws,
|
327 |
+
'render_Ks': self.render_Ks,
|
328 |
+
}
|
329 |
+
return data
|
src/model.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.transforms import v2
|
6 |
+
from torchvision.utils import make_grid, save_image
|
7 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from src.utils.train_util import instantiate_from_config
|
12 |
+
|
13 |
+
|
14 |
+
class MVRecon(pl.LightningModule):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
lrm_generator_config,
|
18 |
+
lrm_path=None,
|
19 |
+
input_size=256,
|
20 |
+
render_size=192,
|
21 |
+
):
|
22 |
+
super(MVRecon, self).__init__()
|
23 |
+
|
24 |
+
self.input_size = input_size
|
25 |
+
self.render_size = render_size
|
26 |
+
|
27 |
+
# init modules
|
28 |
+
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
29 |
+
if lrm_path is not None:
|
30 |
+
lrm_ckpt = torch.load(lrm_path)
|
31 |
+
self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False)
|
32 |
+
|
33 |
+
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
34 |
+
|
35 |
+
self.validation_step_outputs = []
|
36 |
+
|
37 |
+
def on_fit_start(self):
|
38 |
+
if self.global_rank == 0:
|
39 |
+
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
40 |
+
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
41 |
+
|
42 |
+
def prepare_batch_data(self, batch):
|
43 |
+
lrm_generator_input = {}
|
44 |
+
render_gt = {} # for supervision
|
45 |
+
|
46 |
+
# input images
|
47 |
+
images = batch['input_images']
|
48 |
+
images = v2.functional.resize(
|
49 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
50 |
+
|
51 |
+
lrm_generator_input['images'] = images.to(self.device)
|
52 |
+
|
53 |
+
# input cameras and render cameras
|
54 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
55 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
56 |
+
target_c2ws = batch['target_c2ws'].flatten(-2)
|
57 |
+
target_Ks = batch['target_Ks'].flatten(-2)
|
58 |
+
render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1)
|
59 |
+
render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1)
|
60 |
+
render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1)
|
61 |
+
|
62 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
63 |
+
input_intrinsics = torch.stack([
|
64 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
65 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
66 |
+
], dim=-1)
|
67 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
68 |
+
|
69 |
+
# add noise to input cameras
|
70 |
+
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
71 |
+
|
72 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
73 |
+
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
74 |
+
|
75 |
+
# target images
|
76 |
+
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
77 |
+
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
78 |
+
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
79 |
+
|
80 |
+
# random crop
|
81 |
+
render_size = np.random.randint(self.render_size, 513)
|
82 |
+
target_images = v2.functional.resize(
|
83 |
+
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
84 |
+
target_depths = v2.functional.resize(
|
85 |
+
target_depths, render_size, interpolation=0, antialias=True)
|
86 |
+
target_alphas = v2.functional.resize(
|
87 |
+
target_alphas, render_size, interpolation=0, antialias=True)
|
88 |
+
|
89 |
+
crop_params = v2.RandomCrop.get_params(
|
90 |
+
target_images, output_size=(self.render_size, self.render_size))
|
91 |
+
target_images = v2.functional.crop(target_images, *crop_params)
|
92 |
+
target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1]
|
93 |
+
target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1]
|
94 |
+
|
95 |
+
lrm_generator_input['render_size'] = render_size
|
96 |
+
lrm_generator_input['crop_params'] = crop_params
|
97 |
+
|
98 |
+
render_gt['target_images'] = target_images.to(self.device)
|
99 |
+
render_gt['target_depths'] = target_depths.to(self.device)
|
100 |
+
render_gt['target_alphas'] = target_alphas.to(self.device)
|
101 |
+
|
102 |
+
return lrm_generator_input, render_gt
|
103 |
+
|
104 |
+
def prepare_validation_batch_data(self, batch):
|
105 |
+
lrm_generator_input = {}
|
106 |
+
|
107 |
+
# input images
|
108 |
+
images = batch['input_images']
|
109 |
+
images = v2.functional.resize(
|
110 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
111 |
+
|
112 |
+
lrm_generator_input['images'] = images.to(self.device)
|
113 |
+
|
114 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
115 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
116 |
+
|
117 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
118 |
+
input_intrinsics = torch.stack([
|
119 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
120 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
121 |
+
], dim=-1)
|
122 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
123 |
+
|
124 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
125 |
+
|
126 |
+
render_c2ws = batch['render_c2ws'].flatten(-2)
|
127 |
+
render_Ks = batch['render_Ks'].flatten(-2)
|
128 |
+
render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1)
|
129 |
+
|
130 |
+
lrm_generator_input['render_cameras'] = render_cameras.to(self.device)
|
131 |
+
lrm_generator_input['render_size'] = 384
|
132 |
+
lrm_generator_input['crop_params'] = None
|
133 |
+
|
134 |
+
return lrm_generator_input
|
135 |
+
|
136 |
+
def forward_lrm_generator(
|
137 |
+
self,
|
138 |
+
images,
|
139 |
+
cameras,
|
140 |
+
render_cameras,
|
141 |
+
render_size=192,
|
142 |
+
crop_params=None,
|
143 |
+
chunk_size=1,
|
144 |
+
):
|
145 |
+
planes = torch.utils.checkpoint.checkpoint(
|
146 |
+
self.lrm_generator.forward_planes,
|
147 |
+
images,
|
148 |
+
cameras,
|
149 |
+
use_reentrant=False,
|
150 |
+
)
|
151 |
+
frames = []
|
152 |
+
for i in range(0, render_cameras.shape[1], chunk_size):
|
153 |
+
frames.append(
|
154 |
+
torch.utils.checkpoint.checkpoint(
|
155 |
+
self.lrm_generator.synthesizer,
|
156 |
+
planes,
|
157 |
+
cameras=render_cameras[:, i:i+chunk_size],
|
158 |
+
render_size=render_size,
|
159 |
+
crop_params=crop_params,
|
160 |
+
use_reentrant=False
|
161 |
+
)
|
162 |
+
)
|
163 |
+
frames = {
|
164 |
+
k: torch.cat([r[k] for r in frames], dim=1)
|
165 |
+
for k in frames[0].keys()
|
166 |
+
}
|
167 |
+
return frames
|
168 |
+
|
169 |
+
def forward(self, lrm_generator_input):
|
170 |
+
images = lrm_generator_input['images']
|
171 |
+
cameras = lrm_generator_input['cameras']
|
172 |
+
render_cameras = lrm_generator_input['render_cameras']
|
173 |
+
render_size = lrm_generator_input['render_size']
|
174 |
+
crop_params = lrm_generator_input['crop_params']
|
175 |
+
|
176 |
+
out = self.forward_lrm_generator(
|
177 |
+
images,
|
178 |
+
cameras,
|
179 |
+
render_cameras,
|
180 |
+
render_size=render_size,
|
181 |
+
crop_params=crop_params,
|
182 |
+
chunk_size=1,
|
183 |
+
)
|
184 |
+
render_images = torch.clamp(out['images_rgb'], 0.0, 1.0)
|
185 |
+
render_depths = out['images_depth']
|
186 |
+
render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0)
|
187 |
+
|
188 |
+
out = {
|
189 |
+
'render_images': render_images,
|
190 |
+
'render_depths': render_depths,
|
191 |
+
'render_alphas': render_alphas,
|
192 |
+
}
|
193 |
+
return out
|
194 |
+
|
195 |
+
def training_step(self, batch, batch_idx):
|
196 |
+
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
197 |
+
|
198 |
+
render_out = self.forward(lrm_generator_input)
|
199 |
+
|
200 |
+
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
201 |
+
|
202 |
+
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
203 |
+
|
204 |
+
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
205 |
+
B, N, C, H, W = render_gt['target_images'].shape
|
206 |
+
N_in = lrm_generator_input['images'].shape[1]
|
207 |
+
|
208 |
+
input_images = v2.functional.resize(
|
209 |
+
lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1)
|
210 |
+
input_images = torch.cat(
|
211 |
+
[input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1)
|
212 |
+
|
213 |
+
input_images = rearrange(
|
214 |
+
input_images, 'b n c h w -> b c h (n w)')
|
215 |
+
target_images = rearrange(
|
216 |
+
render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
217 |
+
render_images = rearrange(
|
218 |
+
render_out['render_images'], 'b n c h w -> b c h (n w)')
|
219 |
+
target_alphas = rearrange(
|
220 |
+
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
221 |
+
render_alphas = rearrange(
|
222 |
+
repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
223 |
+
target_depths = rearrange(
|
224 |
+
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
225 |
+
render_depths = rearrange(
|
226 |
+
repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
227 |
+
MAX_DEPTH = torch.max(target_depths)
|
228 |
+
target_depths = target_depths / MAX_DEPTH * target_alphas
|
229 |
+
render_depths = render_depths / MAX_DEPTH
|
230 |
+
|
231 |
+
grid = torch.cat([
|
232 |
+
input_images,
|
233 |
+
target_images, render_images,
|
234 |
+
target_alphas, render_alphas,
|
235 |
+
target_depths, render_depths,
|
236 |
+
], dim=-2)
|
237 |
+
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
238 |
+
|
239 |
+
save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png'))
|
240 |
+
|
241 |
+
return loss
|
242 |
+
|
243 |
+
def compute_loss(self, render_out, render_gt):
|
244 |
+
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
245 |
+
render_images = render_out['render_images']
|
246 |
+
target_images = render_gt['target_images'].to(render_images)
|
247 |
+
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
248 |
+
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
249 |
+
|
250 |
+
loss_mse = F.mse_loss(render_images, target_images)
|
251 |
+
loss_lpips = 2.0 * self.lpips(render_images, target_images)
|
252 |
+
|
253 |
+
render_alphas = render_out['render_alphas']
|
254 |
+
target_alphas = render_gt['target_alphas']
|
255 |
+
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
256 |
+
|
257 |
+
loss = loss_mse + loss_lpips + loss_mask
|
258 |
+
|
259 |
+
prefix = 'train'
|
260 |
+
loss_dict = {}
|
261 |
+
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
|
262 |
+
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
|
263 |
+
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
|
264 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
265 |
+
|
266 |
+
return loss, loss_dict
|
267 |
+
|
268 |
+
@torch.no_grad()
|
269 |
+
def validation_step(self, batch, batch_idx):
|
270 |
+
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
271 |
+
|
272 |
+
render_out = self.forward(lrm_generator_input)
|
273 |
+
render_images = render_out['render_images']
|
274 |
+
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
|
275 |
+
|
276 |
+
self.validation_step_outputs.append(render_images)
|
277 |
+
|
278 |
+
def on_validation_epoch_end(self):
|
279 |
+
images = torch.cat(self.validation_step_outputs, dim=-1)
|
280 |
+
|
281 |
+
all_images = self.all_gather(images)
|
282 |
+
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
283 |
+
|
284 |
+
if self.global_rank == 0:
|
285 |
+
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
286 |
+
|
287 |
+
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
288 |
+
save_image(grid, image_path)
|
289 |
+
print(f"Saved image to {image_path}")
|
290 |
+
|
291 |
+
self.validation_step_outputs.clear()
|
292 |
+
|
293 |
+
def configure_optimizers(self):
|
294 |
+
lr = self.learning_rate
|
295 |
+
|
296 |
+
params = []
|
297 |
+
|
298 |
+
lrm_params_fast, lrm_params_slow = [], []
|
299 |
+
for n, p in self.lrm_generator.named_parameters():
|
300 |
+
if 'adaLN_modulation' in n or 'camera_embedder' in n:
|
301 |
+
lrm_params_fast.append(p)
|
302 |
+
else:
|
303 |
+
lrm_params_slow.append(p)
|
304 |
+
params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 })
|
305 |
+
params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 })
|
306 |
+
|
307 |
+
optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95))
|
308 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4)
|
309 |
+
|
310 |
+
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
src/model_mesh.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.transforms import v2
|
6 |
+
from torchvision.utils import make_grid, save_image
|
7 |
+
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from einops import rearrange, repeat
|
10 |
+
|
11 |
+
from src.utils.train_util import instantiate_from_config
|
12 |
+
|
13 |
+
|
14 |
+
# Regulrarization loss for FlexiCubes
|
15 |
+
def sdf_reg_loss_batch(sdf, all_edges):
|
16 |
+
sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2)
|
17 |
+
mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1])
|
18 |
+
sdf_f1x6x2 = sdf_f1x6x2[mask]
|
19 |
+
sdf_diff = F.binary_cross_entropy_with_logits(
|
20 |
+
sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \
|
21 |
+
F.binary_cross_entropy_with_logits(
|
22 |
+
sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float())
|
23 |
+
return sdf_diff
|
24 |
+
|
25 |
+
|
26 |
+
class MVRecon(pl.LightningModule):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
lrm_generator_config,
|
30 |
+
input_size=256,
|
31 |
+
render_size=512,
|
32 |
+
init_ckpt=None,
|
33 |
+
):
|
34 |
+
super(MVRecon, self).__init__()
|
35 |
+
|
36 |
+
self.input_size = input_size
|
37 |
+
self.render_size = render_size
|
38 |
+
|
39 |
+
# init modules
|
40 |
+
self.lrm_generator = instantiate_from_config(lrm_generator_config)
|
41 |
+
|
42 |
+
self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg')
|
43 |
+
|
44 |
+
# Load weights from pretrained MVRecon model, and use the mlp
|
45 |
+
# weights to initialize the weights of sdf and rgb mlps.
|
46 |
+
if init_ckpt is not None:
|
47 |
+
sd = torch.load(init_ckpt, map_location='cpu')['state_dict']
|
48 |
+
sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')}
|
49 |
+
sd_fc = {}
|
50 |
+
for k, v in sd.items():
|
51 |
+
if k.startswith('lrm_generator.synthesizer.decoder.net.'):
|
52 |
+
if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer
|
53 |
+
# Here we assume the density filed's isosurface threshold is t,
|
54 |
+
# we reverse the sign of density filed to initialize SDF field.
|
55 |
+
# -(w*x + b - t) = (-w)*x + (t - b)
|
56 |
+
if 'weight' in k:
|
57 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1]
|
58 |
+
else:
|
59 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1]
|
60 |
+
sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4]
|
61 |
+
else:
|
62 |
+
sd_fc[k.replace('net.', 'net_sdf.')] = v
|
63 |
+
sd_fc[k.replace('net.', 'net_rgb.')] = v
|
64 |
+
else:
|
65 |
+
sd_fc[k] = v
|
66 |
+
sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()}
|
67 |
+
# missing `net_deformation` and `net_weight` parameters
|
68 |
+
self.lrm_generator.load_state_dict(sd_fc, strict=False)
|
69 |
+
print(f'Loaded weights from {init_ckpt}')
|
70 |
+
|
71 |
+
self.validation_step_outputs = []
|
72 |
+
|
73 |
+
def on_fit_start(self):
|
74 |
+
device = torch.device(f'cuda:{self.global_rank}')
|
75 |
+
self.lrm_generator.init_flexicubes_geometry(device)
|
76 |
+
if self.global_rank == 0:
|
77 |
+
os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True)
|
78 |
+
os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True)
|
79 |
+
|
80 |
+
def prepare_batch_data(self, batch):
|
81 |
+
lrm_generator_input = {}
|
82 |
+
render_gt = {}
|
83 |
+
|
84 |
+
# input images
|
85 |
+
images = batch['input_images']
|
86 |
+
images = v2.functional.resize(
|
87 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
88 |
+
|
89 |
+
lrm_generator_input['images'] = images.to(self.device)
|
90 |
+
|
91 |
+
# input cameras and render cameras
|
92 |
+
input_c2ws = batch['input_c2ws']
|
93 |
+
input_Ks = batch['input_Ks']
|
94 |
+
target_c2ws = batch['target_c2ws']
|
95 |
+
|
96 |
+
render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1)
|
97 |
+
render_w2cs = torch.linalg.inv(render_c2ws)
|
98 |
+
|
99 |
+
input_extrinsics = input_c2ws.flatten(-2)
|
100 |
+
input_extrinsics = input_extrinsics[:, :, :12]
|
101 |
+
input_intrinsics = input_Ks.flatten(-2)
|
102 |
+
input_intrinsics = torch.stack([
|
103 |
+
input_intrinsics[:, :, 0], input_intrinsics[:, :, 4],
|
104 |
+
input_intrinsics[:, :, 2], input_intrinsics[:, :, 5],
|
105 |
+
], dim=-1)
|
106 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
107 |
+
|
108 |
+
# add noise to input_cameras
|
109 |
+
cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02
|
110 |
+
|
111 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
112 |
+
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
113 |
+
|
114 |
+
# target images
|
115 |
+
target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1)
|
116 |
+
target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1)
|
117 |
+
target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1)
|
118 |
+
target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1)
|
119 |
+
|
120 |
+
render_size = self.render_size
|
121 |
+
target_images = v2.functional.resize(
|
122 |
+
target_images, render_size, interpolation=3, antialias=True).clamp(0, 1)
|
123 |
+
target_depths = v2.functional.resize(
|
124 |
+
target_depths, render_size, interpolation=0, antialias=True)
|
125 |
+
target_alphas = v2.functional.resize(
|
126 |
+
target_alphas, render_size, interpolation=0, antialias=True)
|
127 |
+
target_normals = v2.functional.resize(
|
128 |
+
target_normals, render_size, interpolation=3, antialias=True)
|
129 |
+
|
130 |
+
lrm_generator_input['render_size'] = render_size
|
131 |
+
|
132 |
+
render_gt['target_images'] = target_images.to(self.device)
|
133 |
+
render_gt['target_depths'] = target_depths.to(self.device)
|
134 |
+
render_gt['target_alphas'] = target_alphas.to(self.device)
|
135 |
+
render_gt['target_normals'] = target_normals.to(self.device)
|
136 |
+
|
137 |
+
return lrm_generator_input, render_gt
|
138 |
+
|
139 |
+
def prepare_validation_batch_data(self, batch):
|
140 |
+
lrm_generator_input = {}
|
141 |
+
|
142 |
+
# input images
|
143 |
+
images = batch['input_images']
|
144 |
+
images = v2.functional.resize(
|
145 |
+
images, self.input_size, interpolation=3, antialias=True).clamp(0, 1)
|
146 |
+
|
147 |
+
lrm_generator_input['images'] = images.to(self.device)
|
148 |
+
|
149 |
+
# input cameras
|
150 |
+
input_c2ws = batch['input_c2ws'].flatten(-2)
|
151 |
+
input_Ks = batch['input_Ks'].flatten(-2)
|
152 |
+
|
153 |
+
input_extrinsics = input_c2ws[:, :, :12]
|
154 |
+
input_intrinsics = torch.stack([
|
155 |
+
input_Ks[:, :, 0], input_Ks[:, :, 4],
|
156 |
+
input_Ks[:, :, 2], input_Ks[:, :, 5],
|
157 |
+
], dim=-1)
|
158 |
+
cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1)
|
159 |
+
|
160 |
+
lrm_generator_input['cameras'] = cameras.to(self.device)
|
161 |
+
|
162 |
+
# render cameras
|
163 |
+
render_c2ws = batch['render_c2ws']
|
164 |
+
render_w2cs = torch.linalg.inv(render_c2ws)
|
165 |
+
|
166 |
+
lrm_generator_input['render_cameras'] = render_w2cs.to(self.device)
|
167 |
+
lrm_generator_input['render_size'] = 384
|
168 |
+
|
169 |
+
return lrm_generator_input
|
170 |
+
|
171 |
+
def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512):
|
172 |
+
planes = torch.utils.checkpoint.checkpoint(
|
173 |
+
self.lrm_generator.forward_planes,
|
174 |
+
images,
|
175 |
+
cameras,
|
176 |
+
use_reentrant=False,
|
177 |
+
)
|
178 |
+
out = self.lrm_generator.forward_geometry(
|
179 |
+
planes,
|
180 |
+
render_cameras,
|
181 |
+
render_size,
|
182 |
+
)
|
183 |
+
return out
|
184 |
+
|
185 |
+
def forward(self, lrm_generator_input):
|
186 |
+
images = lrm_generator_input['images']
|
187 |
+
cameras = lrm_generator_input['cameras']
|
188 |
+
render_cameras = lrm_generator_input['render_cameras']
|
189 |
+
render_size = lrm_generator_input['render_size']
|
190 |
+
|
191 |
+
out = self.forward_lrm_generator(
|
192 |
+
images, cameras, render_cameras, render_size=render_size)
|
193 |
+
|
194 |
+
return out
|
195 |
+
|
196 |
+
def training_step(self, batch, batch_idx):
|
197 |
+
lrm_generator_input, render_gt = self.prepare_batch_data(batch)
|
198 |
+
|
199 |
+
render_out = self.forward(lrm_generator_input)
|
200 |
+
|
201 |
+
loss, loss_dict = self.compute_loss(render_out, render_gt)
|
202 |
+
|
203 |
+
self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
204 |
+
|
205 |
+
if self.global_step % 1000 == 0 and self.global_rank == 0:
|
206 |
+
B, N, C, H, W = render_gt['target_images'].shape
|
207 |
+
N_in = lrm_generator_input['images'].shape[1]
|
208 |
+
|
209 |
+
target_images = rearrange(
|
210 |
+
render_gt['target_images'], 'b n c h w -> b c h (n w)')
|
211 |
+
render_images = rearrange(
|
212 |
+
render_out['img'], 'b n c h w -> b c h (n w)')
|
213 |
+
target_alphas = rearrange(
|
214 |
+
repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
215 |
+
render_alphas = rearrange(
|
216 |
+
repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
217 |
+
target_depths = rearrange(
|
218 |
+
repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
219 |
+
render_depths = rearrange(
|
220 |
+
repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)')
|
221 |
+
target_normals = rearrange(
|
222 |
+
render_gt['target_normals'], 'b n c h w -> b c h (n w)')
|
223 |
+
render_normals = rearrange(
|
224 |
+
render_out['normal'], 'b n c h w -> b c h (n w)')
|
225 |
+
MAX_DEPTH = torch.max(target_depths)
|
226 |
+
target_depths = target_depths / MAX_DEPTH * target_alphas
|
227 |
+
render_depths = render_depths / MAX_DEPTH
|
228 |
+
|
229 |
+
grid = torch.cat([
|
230 |
+
target_images, render_images,
|
231 |
+
target_alphas, render_alphas,
|
232 |
+
target_depths, render_depths,
|
233 |
+
target_normals, render_normals,
|
234 |
+
], dim=-2)
|
235 |
+
grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1))
|
236 |
+
|
237 |
+
image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')
|
238 |
+
save_image(grid, image_path)
|
239 |
+
print(f"Saved image to {image_path}")
|
240 |
+
|
241 |
+
return loss
|
242 |
+
|
243 |
+
def compute_loss(self, render_out, render_gt):
|
244 |
+
# NOTE: the rgb value range of OpenLRM is [0, 1]
|
245 |
+
render_images = render_out['img']
|
246 |
+
target_images = render_gt['target_images'].to(render_images)
|
247 |
+
render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
248 |
+
target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0
|
249 |
+
loss_mse = F.mse_loss(render_images, target_images)
|
250 |
+
loss_lpips = 2.0 * self.lpips(render_images, target_images)
|
251 |
+
|
252 |
+
render_alphas = render_out['mask']
|
253 |
+
target_alphas = render_gt['target_alphas']
|
254 |
+
loss_mask = F.mse_loss(render_alphas, target_alphas)
|
255 |
+
|
256 |
+
render_depths = render_out['depth']
|
257 |
+
target_depths = render_gt['target_depths']
|
258 |
+
loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0])
|
259 |
+
|
260 |
+
render_normals = render_out['normal'] * 2.0 - 1.0
|
261 |
+
target_normals = render_gt['target_normals'] * 2.0 - 1.0
|
262 |
+
similarity = (render_normals * target_normals).sum(dim=-3).abs()
|
263 |
+
normal_mask = target_alphas.squeeze(-3)
|
264 |
+
loss_normal = 1 - similarity[normal_mask>0].mean()
|
265 |
+
loss_normal = 0.2 * loss_normal
|
266 |
+
|
267 |
+
# flexicubes regularization loss
|
268 |
+
sdf = render_out['sdf']
|
269 |
+
sdf_reg_loss = render_out['sdf_reg_loss']
|
270 |
+
sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01
|
271 |
+
_, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss
|
272 |
+
flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5
|
273 |
+
flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1
|
274 |
+
|
275 |
+
loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg
|
276 |
+
|
277 |
+
loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg
|
278 |
+
|
279 |
+
prefix = 'train'
|
280 |
+
loss_dict = {}
|
281 |
+
loss_dict.update({f'{prefix}/loss_mse': loss_mse})
|
282 |
+
loss_dict.update({f'{prefix}/loss_lpips': loss_lpips})
|
283 |
+
loss_dict.update({f'{prefix}/loss_mask': loss_mask})
|
284 |
+
loss_dict.update({f'{prefix}/loss_normal': loss_normal})
|
285 |
+
loss_dict.update({f'{prefix}/loss_depth': loss_depth})
|
286 |
+
loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy})
|
287 |
+
loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg})
|
288 |
+
loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg})
|
289 |
+
loss_dict.update({f'{prefix}/loss': loss})
|
290 |
+
|
291 |
+
return loss, loss_dict
|
292 |
+
|
293 |
+
@torch.no_grad()
|
294 |
+
def validation_step(self, batch, batch_idx):
|
295 |
+
lrm_generator_input = self.prepare_validation_batch_data(batch)
|
296 |
+
|
297 |
+
render_out = self.forward(lrm_generator_input)
|
298 |
+
render_images = render_out['img']
|
299 |
+
render_images = rearrange(render_images, 'b n c h w -> b c h (n w)')
|
300 |
+
|
301 |
+
self.validation_step_outputs.append(render_images)
|
302 |
+
|
303 |
+
def on_validation_epoch_end(self):
|
304 |
+
images = torch.cat(self.validation_step_outputs, dim=-1)
|
305 |
+
|
306 |
+
all_images = self.all_gather(images)
|
307 |
+
all_images = rearrange(all_images, 'r b c h w -> (r b) c h w')
|
308 |
+
|
309 |
+
if self.global_rank == 0:
|
310 |
+
image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png')
|
311 |
+
|
312 |
+
grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1))
|
313 |
+
save_image(grid, image_path)
|
314 |
+
print(f"Saved image to {image_path}")
|
315 |
+
|
316 |
+
self.validation_step_outputs.clear()
|
317 |
+
|
318 |
+
def configure_optimizers(self):
|
319 |
+
lr = self.learning_rate
|
320 |
+
|
321 |
+
optimizer = torch.optim.AdamW(
|
322 |
+
self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01)
|
323 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0)
|
324 |
+
|
325 |
+
return {'optimizer': optimizer, 'lr_scheduler': scheduler}
|
src/models/__init__.py
ADDED
File without changes
|
src/models/decoder/__init__.py
ADDED
File without changes
|
src/models/decoder/transformer.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
|
19 |
+
|
20 |
+
class BasicTransformerBlock(nn.Module):
|
21 |
+
"""
|
22 |
+
Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks.
|
23 |
+
"""
|
24 |
+
# use attention from torch.nn.MultiHeadAttention
|
25 |
+
# Block contains a cross-attention layer, a self-attention layer, and a MLP
|
26 |
+
def __init__(
|
27 |
+
self,
|
28 |
+
inner_dim: int,
|
29 |
+
cond_dim: int,
|
30 |
+
num_heads: int,
|
31 |
+
eps: float,
|
32 |
+
attn_drop: float = 0.,
|
33 |
+
attn_bias: bool = False,
|
34 |
+
mlp_ratio: float = 4.,
|
35 |
+
mlp_drop: float = 0.,
|
36 |
+
):
|
37 |
+
super().__init__()
|
38 |
+
|
39 |
+
self.norm1 = nn.LayerNorm(inner_dim)
|
40 |
+
self.cross_attn = nn.MultiheadAttention(
|
41 |
+
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
42 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
43 |
+
self.norm2 = nn.LayerNorm(inner_dim)
|
44 |
+
self.self_attn = nn.MultiheadAttention(
|
45 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
46 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
47 |
+
self.norm3 = nn.LayerNorm(inner_dim)
|
48 |
+
self.mlp = nn.Sequential(
|
49 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
50 |
+
nn.GELU(),
|
51 |
+
nn.Dropout(mlp_drop),
|
52 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
53 |
+
nn.Dropout(mlp_drop),
|
54 |
+
)
|
55 |
+
|
56 |
+
def forward(self, x, cond):
|
57 |
+
# x: [N, L, D]
|
58 |
+
# cond: [N, L_cond, D_cond]
|
59 |
+
x = x + self.cross_attn(self.norm1(x), cond, cond)[0]
|
60 |
+
before_sa = self.norm2(x)
|
61 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa)[0]
|
62 |
+
x = x + self.mlp(self.norm3(x))
|
63 |
+
return x
|
64 |
+
|
65 |
+
|
66 |
+
class TriplaneTransformer(nn.Module):
|
67 |
+
"""
|
68 |
+
Transformer with condition that generates a triplane representation.
|
69 |
+
|
70 |
+
Reference:
|
71 |
+
Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486
|
72 |
+
"""
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
inner_dim: int,
|
76 |
+
image_feat_dim: int,
|
77 |
+
triplane_low_res: int,
|
78 |
+
triplane_high_res: int,
|
79 |
+
triplane_dim: int,
|
80 |
+
num_layers: int,
|
81 |
+
num_heads: int,
|
82 |
+
eps: float = 1e-6,
|
83 |
+
):
|
84 |
+
super().__init__()
|
85 |
+
|
86 |
+
# attributes
|
87 |
+
self.triplane_low_res = triplane_low_res
|
88 |
+
self.triplane_high_res = triplane_high_res
|
89 |
+
self.triplane_dim = triplane_dim
|
90 |
+
|
91 |
+
# modules
|
92 |
+
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
|
93 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5)
|
94 |
+
self.layers = nn.ModuleList([
|
95 |
+
BasicTransformerBlock(
|
96 |
+
inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps)
|
97 |
+
for _ in range(num_layers)
|
98 |
+
])
|
99 |
+
self.norm = nn.LayerNorm(inner_dim, eps=eps)
|
100 |
+
self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0)
|
101 |
+
|
102 |
+
def forward(self, image_feats):
|
103 |
+
# image_feats: [N, L_cond, D_cond]
|
104 |
+
|
105 |
+
N = image_feats.shape[0]
|
106 |
+
H = W = self.triplane_low_res
|
107 |
+
L = 3 * H * W
|
108 |
+
|
109 |
+
x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
|
110 |
+
for layer in self.layers:
|
111 |
+
x = layer(x, image_feats)
|
112 |
+
x = self.norm(x)
|
113 |
+
|
114 |
+
# separate each plane and apply deconv
|
115 |
+
x = x.view(N, 3, H, W, -1)
|
116 |
+
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
|
117 |
+
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
118 |
+
x = self.deconv(x) # [3*N, D', H', W']
|
119 |
+
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
120 |
+
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
121 |
+
x = x.contiguous()
|
122 |
+
|
123 |
+
return x
|
src/models/encoder/__init__.py
ADDED
File without changes
|
src/models/encoder/dino.py
ADDED
@@ -0,0 +1,550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch ViT model."""
|
16 |
+
|
17 |
+
|
18 |
+
import collections.abc
|
19 |
+
import math
|
20 |
+
from typing import Dict, List, Optional, Set, Tuple, Union
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch import nn
|
24 |
+
|
25 |
+
from transformers.activations import ACT2FN
|
26 |
+
from transformers.modeling_outputs import (
|
27 |
+
BaseModelOutput,
|
28 |
+
BaseModelOutputWithPooling,
|
29 |
+
)
|
30 |
+
from transformers import PreTrainedModel, ViTConfig
|
31 |
+
from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
|
32 |
+
|
33 |
+
|
34 |
+
class ViTEmbeddings(nn.Module):
|
35 |
+
"""
|
36 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None:
|
40 |
+
super().__init__()
|
41 |
+
|
42 |
+
self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
|
43 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
|
44 |
+
self.patch_embeddings = ViTPatchEmbeddings(config)
|
45 |
+
num_patches = self.patch_embeddings.num_patches
|
46 |
+
self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size))
|
47 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
48 |
+
self.config = config
|
49 |
+
|
50 |
+
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
51 |
+
"""
|
52 |
+
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
53 |
+
resolution images.
|
54 |
+
|
55 |
+
Source:
|
56 |
+
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
57 |
+
"""
|
58 |
+
|
59 |
+
num_patches = embeddings.shape[1] - 1
|
60 |
+
num_positions = self.position_embeddings.shape[1] - 1
|
61 |
+
if num_patches == num_positions and height == width:
|
62 |
+
return self.position_embeddings
|
63 |
+
class_pos_embed = self.position_embeddings[:, 0]
|
64 |
+
patch_pos_embed = self.position_embeddings[:, 1:]
|
65 |
+
dim = embeddings.shape[-1]
|
66 |
+
h0 = height // self.config.patch_size
|
67 |
+
w0 = width // self.config.patch_size
|
68 |
+
# we add a small number to avoid floating point error in the interpolation
|
69 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
70 |
+
h0, w0 = h0 + 0.1, w0 + 0.1
|
71 |
+
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
72 |
+
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
73 |
+
patch_pos_embed = nn.functional.interpolate(
|
74 |
+
patch_pos_embed,
|
75 |
+
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
76 |
+
mode="bicubic",
|
77 |
+
align_corners=False,
|
78 |
+
)
|
79 |
+
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
80 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
81 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
82 |
+
|
83 |
+
def forward(
|
84 |
+
self,
|
85 |
+
pixel_values: torch.Tensor,
|
86 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
87 |
+
interpolate_pos_encoding: bool = False,
|
88 |
+
) -> torch.Tensor:
|
89 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
90 |
+
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
|
91 |
+
|
92 |
+
if bool_masked_pos is not None:
|
93 |
+
seq_length = embeddings.shape[1]
|
94 |
+
mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
|
95 |
+
# replace the masked visual tokens by mask_tokens
|
96 |
+
mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
97 |
+
embeddings = embeddings * (1.0 - mask) + mask_tokens * mask
|
98 |
+
|
99 |
+
# add the [CLS] token to the embedded patch tokens
|
100 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
101 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
102 |
+
|
103 |
+
# add positional encoding to each token
|
104 |
+
if interpolate_pos_encoding:
|
105 |
+
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
106 |
+
else:
|
107 |
+
embeddings = embeddings + self.position_embeddings
|
108 |
+
|
109 |
+
embeddings = self.dropout(embeddings)
|
110 |
+
|
111 |
+
return embeddings
|
112 |
+
|
113 |
+
|
114 |
+
class ViTPatchEmbeddings(nn.Module):
|
115 |
+
"""
|
116 |
+
This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
|
117 |
+
`hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a
|
118 |
+
Transformer.
|
119 |
+
"""
|
120 |
+
|
121 |
+
def __init__(self, config):
|
122 |
+
super().__init__()
|
123 |
+
image_size, patch_size = config.image_size, config.patch_size
|
124 |
+
num_channels, hidden_size = config.num_channels, config.hidden_size
|
125 |
+
|
126 |
+
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
|
127 |
+
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
|
128 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
129 |
+
self.image_size = image_size
|
130 |
+
self.patch_size = patch_size
|
131 |
+
self.num_channels = num_channels
|
132 |
+
self.num_patches = num_patches
|
133 |
+
|
134 |
+
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
|
135 |
+
|
136 |
+
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
|
137 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
138 |
+
if num_channels != self.num_channels:
|
139 |
+
raise ValueError(
|
140 |
+
"Make sure that the channel dimension of the pixel values match with the one set in the configuration."
|
141 |
+
f" Expected {self.num_channels} but got {num_channels}."
|
142 |
+
)
|
143 |
+
if not interpolate_pos_encoding:
|
144 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
145 |
+
raise ValueError(
|
146 |
+
f"Input image size ({height}*{width}) doesn't match model"
|
147 |
+
f" ({self.image_size[0]}*{self.image_size[1]})."
|
148 |
+
)
|
149 |
+
embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
150 |
+
return embeddings
|
151 |
+
|
152 |
+
|
153 |
+
class ViTSelfAttention(nn.Module):
|
154 |
+
def __init__(self, config: ViTConfig) -> None:
|
155 |
+
super().__init__()
|
156 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
157 |
+
raise ValueError(
|
158 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
159 |
+
f"heads {config.num_attention_heads}."
|
160 |
+
)
|
161 |
+
|
162 |
+
self.num_attention_heads = config.num_attention_heads
|
163 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
164 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
165 |
+
|
166 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
167 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
168 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
|
169 |
+
|
170 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
171 |
+
|
172 |
+
def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
|
173 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
174 |
+
x = x.view(new_x_shape)
|
175 |
+
return x.permute(0, 2, 1, 3)
|
176 |
+
|
177 |
+
def forward(
|
178 |
+
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
|
179 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
180 |
+
mixed_query_layer = self.query(hidden_states)
|
181 |
+
|
182 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
183 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
184 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
185 |
+
|
186 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
187 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
188 |
+
|
189 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
190 |
+
|
191 |
+
# Normalize the attention scores to probabilities.
|
192 |
+
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
193 |
+
|
194 |
+
# This is actually dropping out entire tokens to attend to, which might
|
195 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
196 |
+
attention_probs = self.dropout(attention_probs)
|
197 |
+
|
198 |
+
# Mask heads if we want to
|
199 |
+
if head_mask is not None:
|
200 |
+
attention_probs = attention_probs * head_mask
|
201 |
+
|
202 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
203 |
+
|
204 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
205 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
206 |
+
context_layer = context_layer.view(new_context_layer_shape)
|
207 |
+
|
208 |
+
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
209 |
+
|
210 |
+
return outputs
|
211 |
+
|
212 |
+
|
213 |
+
class ViTSelfOutput(nn.Module):
|
214 |
+
"""
|
215 |
+
The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the
|
216 |
+
layernorm applied before each block.
|
217 |
+
"""
|
218 |
+
|
219 |
+
def __init__(self, config: ViTConfig) -> None:
|
220 |
+
super().__init__()
|
221 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
222 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
223 |
+
|
224 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
225 |
+
hidden_states = self.dense(hidden_states)
|
226 |
+
hidden_states = self.dropout(hidden_states)
|
227 |
+
|
228 |
+
return hidden_states
|
229 |
+
|
230 |
+
|
231 |
+
class ViTAttention(nn.Module):
|
232 |
+
def __init__(self, config: ViTConfig) -> None:
|
233 |
+
super().__init__()
|
234 |
+
self.attention = ViTSelfAttention(config)
|
235 |
+
self.output = ViTSelfOutput(config)
|
236 |
+
self.pruned_heads = set()
|
237 |
+
|
238 |
+
def prune_heads(self, heads: Set[int]) -> None:
|
239 |
+
if len(heads) == 0:
|
240 |
+
return
|
241 |
+
heads, index = find_pruneable_heads_and_indices(
|
242 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
243 |
+
)
|
244 |
+
|
245 |
+
# Prune linear layers
|
246 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
247 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
248 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
249 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
250 |
+
|
251 |
+
# Update hyper params and store pruned heads
|
252 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
253 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
254 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
255 |
+
|
256 |
+
def forward(
|
257 |
+
self,
|
258 |
+
hidden_states: torch.Tensor,
|
259 |
+
head_mask: Optional[torch.Tensor] = None,
|
260 |
+
output_attentions: bool = False,
|
261 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
262 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions)
|
263 |
+
|
264 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
265 |
+
|
266 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
267 |
+
return outputs
|
268 |
+
|
269 |
+
|
270 |
+
class ViTIntermediate(nn.Module):
|
271 |
+
def __init__(self, config: ViTConfig) -> None:
|
272 |
+
super().__init__()
|
273 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
274 |
+
if isinstance(config.hidden_act, str):
|
275 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
276 |
+
else:
|
277 |
+
self.intermediate_act_fn = config.hidden_act
|
278 |
+
|
279 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
280 |
+
hidden_states = self.dense(hidden_states)
|
281 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
282 |
+
|
283 |
+
return hidden_states
|
284 |
+
|
285 |
+
|
286 |
+
class ViTOutput(nn.Module):
|
287 |
+
def __init__(self, config: ViTConfig) -> None:
|
288 |
+
super().__init__()
|
289 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
290 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
291 |
+
|
292 |
+
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
293 |
+
hidden_states = self.dense(hidden_states)
|
294 |
+
hidden_states = self.dropout(hidden_states)
|
295 |
+
|
296 |
+
hidden_states = hidden_states + input_tensor
|
297 |
+
|
298 |
+
return hidden_states
|
299 |
+
|
300 |
+
|
301 |
+
def modulate(x, shift, scale):
|
302 |
+
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
303 |
+
|
304 |
+
|
305 |
+
class ViTLayer(nn.Module):
|
306 |
+
"""This corresponds to the Block class in the timm implementation."""
|
307 |
+
|
308 |
+
def __init__(self, config: ViTConfig) -> None:
|
309 |
+
super().__init__()
|
310 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
311 |
+
self.seq_len_dim = 1
|
312 |
+
self.attention = ViTAttention(config)
|
313 |
+
self.intermediate = ViTIntermediate(config)
|
314 |
+
self.output = ViTOutput(config)
|
315 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
316 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
317 |
+
|
318 |
+
self.adaLN_modulation = nn.Sequential(
|
319 |
+
nn.SiLU(),
|
320 |
+
nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True)
|
321 |
+
)
|
322 |
+
nn.init.constant_(self.adaLN_modulation[-1].weight, 0)
|
323 |
+
nn.init.constant_(self.adaLN_modulation[-1].bias, 0)
|
324 |
+
|
325 |
+
def forward(
|
326 |
+
self,
|
327 |
+
hidden_states: torch.Tensor,
|
328 |
+
adaln_input: torch.Tensor = None,
|
329 |
+
head_mask: Optional[torch.Tensor] = None,
|
330 |
+
output_attentions: bool = False,
|
331 |
+
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
|
332 |
+
shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)
|
333 |
+
|
334 |
+
self_attention_outputs = self.attention(
|
335 |
+
modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention
|
336 |
+
head_mask,
|
337 |
+
output_attentions=output_attentions,
|
338 |
+
)
|
339 |
+
attention_output = self_attention_outputs[0]
|
340 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
341 |
+
|
342 |
+
# first residual connection
|
343 |
+
hidden_states = attention_output + hidden_states
|
344 |
+
|
345 |
+
# in ViT, layernorm is also applied after self-attention
|
346 |
+
layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp)
|
347 |
+
layer_output = self.intermediate(layer_output)
|
348 |
+
|
349 |
+
# second residual connection is done here
|
350 |
+
layer_output = self.output(layer_output, hidden_states)
|
351 |
+
|
352 |
+
outputs = (layer_output,) + outputs
|
353 |
+
|
354 |
+
return outputs
|
355 |
+
|
356 |
+
|
357 |
+
class ViTEncoder(nn.Module):
|
358 |
+
def __init__(self, config: ViTConfig) -> None:
|
359 |
+
super().__init__()
|
360 |
+
self.config = config
|
361 |
+
self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)])
|
362 |
+
self.gradient_checkpointing = False
|
363 |
+
|
364 |
+
def forward(
|
365 |
+
self,
|
366 |
+
hidden_states: torch.Tensor,
|
367 |
+
adaln_input: torch.Tensor = None,
|
368 |
+
head_mask: Optional[torch.Tensor] = None,
|
369 |
+
output_attentions: bool = False,
|
370 |
+
output_hidden_states: bool = False,
|
371 |
+
return_dict: bool = True,
|
372 |
+
) -> Union[tuple, BaseModelOutput]:
|
373 |
+
all_hidden_states = () if output_hidden_states else None
|
374 |
+
all_self_attentions = () if output_attentions else None
|
375 |
+
|
376 |
+
for i, layer_module in enumerate(self.layer):
|
377 |
+
if output_hidden_states:
|
378 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
379 |
+
|
380 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
381 |
+
|
382 |
+
if self.gradient_checkpointing and self.training:
|
383 |
+
layer_outputs = self._gradient_checkpointing_func(
|
384 |
+
layer_module.__call__,
|
385 |
+
hidden_states,
|
386 |
+
adaln_input,
|
387 |
+
layer_head_mask,
|
388 |
+
output_attentions,
|
389 |
+
)
|
390 |
+
else:
|
391 |
+
layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions)
|
392 |
+
|
393 |
+
hidden_states = layer_outputs[0]
|
394 |
+
|
395 |
+
if output_attentions:
|
396 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
397 |
+
|
398 |
+
if output_hidden_states:
|
399 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
400 |
+
|
401 |
+
if not return_dict:
|
402 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
403 |
+
return BaseModelOutput(
|
404 |
+
last_hidden_state=hidden_states,
|
405 |
+
hidden_states=all_hidden_states,
|
406 |
+
attentions=all_self_attentions,
|
407 |
+
)
|
408 |
+
|
409 |
+
|
410 |
+
class ViTPreTrainedModel(PreTrainedModel):
|
411 |
+
"""
|
412 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
413 |
+
models.
|
414 |
+
"""
|
415 |
+
|
416 |
+
config_class = ViTConfig
|
417 |
+
base_model_prefix = "vit"
|
418 |
+
main_input_name = "pixel_values"
|
419 |
+
supports_gradient_checkpointing = True
|
420 |
+
_no_split_modules = ["ViTEmbeddings", "ViTLayer"]
|
421 |
+
|
422 |
+
def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
|
423 |
+
"""Initialize the weights"""
|
424 |
+
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
425 |
+
# Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
|
426 |
+
# `trunc_normal_cpu` not implemented in `half` issues
|
427 |
+
module.weight.data = nn.init.trunc_normal_(
|
428 |
+
module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
|
429 |
+
).to(module.weight.dtype)
|
430 |
+
if module.bias is not None:
|
431 |
+
module.bias.data.zero_()
|
432 |
+
elif isinstance(module, nn.LayerNorm):
|
433 |
+
module.bias.data.zero_()
|
434 |
+
module.weight.data.fill_(1.0)
|
435 |
+
elif isinstance(module, ViTEmbeddings):
|
436 |
+
module.position_embeddings.data = nn.init.trunc_normal_(
|
437 |
+
module.position_embeddings.data.to(torch.float32),
|
438 |
+
mean=0.0,
|
439 |
+
std=self.config.initializer_range,
|
440 |
+
).to(module.position_embeddings.dtype)
|
441 |
+
|
442 |
+
module.cls_token.data = nn.init.trunc_normal_(
|
443 |
+
module.cls_token.data.to(torch.float32),
|
444 |
+
mean=0.0,
|
445 |
+
std=self.config.initializer_range,
|
446 |
+
).to(module.cls_token.dtype)
|
447 |
+
|
448 |
+
|
449 |
+
class ViTModel(ViTPreTrainedModel):
|
450 |
+
def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False):
|
451 |
+
super().__init__(config)
|
452 |
+
self.config = config
|
453 |
+
|
454 |
+
self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token)
|
455 |
+
self.encoder = ViTEncoder(config)
|
456 |
+
|
457 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
458 |
+
self.pooler = ViTPooler(config) if add_pooling_layer else None
|
459 |
+
|
460 |
+
# Initialize weights and apply final processing
|
461 |
+
self.post_init()
|
462 |
+
|
463 |
+
def get_input_embeddings(self) -> ViTPatchEmbeddings:
|
464 |
+
return self.embeddings.patch_embeddings
|
465 |
+
|
466 |
+
def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None:
|
467 |
+
"""
|
468 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
469 |
+
class PreTrainedModel
|
470 |
+
"""
|
471 |
+
for layer, heads in heads_to_prune.items():
|
472 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
473 |
+
|
474 |
+
def forward(
|
475 |
+
self,
|
476 |
+
pixel_values: Optional[torch.Tensor] = None,
|
477 |
+
adaln_input: Optional[torch.Tensor] = None,
|
478 |
+
bool_masked_pos: Optional[torch.BoolTensor] = None,
|
479 |
+
head_mask: Optional[torch.Tensor] = None,
|
480 |
+
output_attentions: Optional[bool] = None,
|
481 |
+
output_hidden_states: Optional[bool] = None,
|
482 |
+
interpolate_pos_encoding: Optional[bool] = None,
|
483 |
+
return_dict: Optional[bool] = None,
|
484 |
+
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
485 |
+
r"""
|
486 |
+
bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*):
|
487 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
488 |
+
"""
|
489 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
490 |
+
output_hidden_states = (
|
491 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
492 |
+
)
|
493 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
494 |
+
|
495 |
+
if pixel_values is None:
|
496 |
+
raise ValueError("You have to specify pixel_values")
|
497 |
+
|
498 |
+
# Prepare head mask if needed
|
499 |
+
# 1.0 in head_mask indicate we keep the head
|
500 |
+
# attention_probs has shape bsz x n_heads x N x N
|
501 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
502 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
503 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
504 |
+
|
505 |
+
# TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?)
|
506 |
+
expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype
|
507 |
+
if pixel_values.dtype != expected_dtype:
|
508 |
+
pixel_values = pixel_values.to(expected_dtype)
|
509 |
+
|
510 |
+
embedding_output = self.embeddings(
|
511 |
+
pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding
|
512 |
+
)
|
513 |
+
|
514 |
+
encoder_outputs = self.encoder(
|
515 |
+
embedding_output,
|
516 |
+
adaln_input=adaln_input,
|
517 |
+
head_mask=head_mask,
|
518 |
+
output_attentions=output_attentions,
|
519 |
+
output_hidden_states=output_hidden_states,
|
520 |
+
return_dict=return_dict,
|
521 |
+
)
|
522 |
+
sequence_output = encoder_outputs[0]
|
523 |
+
sequence_output = self.layernorm(sequence_output)
|
524 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
525 |
+
|
526 |
+
if not return_dict:
|
527 |
+
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
|
528 |
+
return head_outputs + encoder_outputs[1:]
|
529 |
+
|
530 |
+
return BaseModelOutputWithPooling(
|
531 |
+
last_hidden_state=sequence_output,
|
532 |
+
pooler_output=pooled_output,
|
533 |
+
hidden_states=encoder_outputs.hidden_states,
|
534 |
+
attentions=encoder_outputs.attentions,
|
535 |
+
)
|
536 |
+
|
537 |
+
|
538 |
+
class ViTPooler(nn.Module):
|
539 |
+
def __init__(self, config: ViTConfig):
|
540 |
+
super().__init__()
|
541 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
542 |
+
self.activation = nn.Tanh()
|
543 |
+
|
544 |
+
def forward(self, hidden_states):
|
545 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
546 |
+
# to the first token.
|
547 |
+
first_token_tensor = hidden_states[:, 0]
|
548 |
+
pooled_output = self.dense(first_token_tensor)
|
549 |
+
pooled_output = self.activation(pooled_output)
|
550 |
+
return pooled_output
|
src/models/encoder/dino_wrapper.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023, Zexin He
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
|
16 |
+
import torch.nn as nn
|
17 |
+
from transformers import ViTImageProcessor
|
18 |
+
from einops import rearrange, repeat
|
19 |
+
from .dino import ViTModel
|
20 |
+
|
21 |
+
|
22 |
+
class DinoWrapper(nn.Module):
|
23 |
+
"""
|
24 |
+
Dino v1 wrapper using huggingface transformer implementation.
|
25 |
+
"""
|
26 |
+
def __init__(self, model_name: str, freeze: bool = True):
|
27 |
+
super().__init__()
|
28 |
+
self.model, self.processor = self._build_dino(model_name)
|
29 |
+
self.camera_embedder = nn.Sequential(
|
30 |
+
nn.Linear(16, self.model.config.hidden_size, bias=True),
|
31 |
+
nn.SiLU(),
|
32 |
+
nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True)
|
33 |
+
)
|
34 |
+
if freeze:
|
35 |
+
self._freeze()
|
36 |
+
|
37 |
+
def forward(self, image, camera):
|
38 |
+
# image: [B, N, C, H, W]
|
39 |
+
# camera: [B, N, D]
|
40 |
+
# RGB image with [0,1] scale and properly sized
|
41 |
+
if image.ndim == 5:
|
42 |
+
image = rearrange(image, 'b n c h w -> (b n) c h w')
|
43 |
+
dtype = image.dtype
|
44 |
+
inputs = self.processor(
|
45 |
+
images=image.float(),
|
46 |
+
return_tensors="pt",
|
47 |
+
do_rescale=False,
|
48 |
+
do_resize=False,
|
49 |
+
).to(self.model.device).to(dtype)
|
50 |
+
# embed camera
|
51 |
+
N = camera.shape[1]
|
52 |
+
camera_embeddings = self.camera_embedder(camera)
|
53 |
+
camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d')
|
54 |
+
embeddings = camera_embeddings
|
55 |
+
# This resampling of positional embedding uses bicubic interpolation
|
56 |
+
outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True)
|
57 |
+
last_hidden_states = outputs.last_hidden_state
|
58 |
+
return last_hidden_states
|
59 |
+
|
60 |
+
def _freeze(self):
|
61 |
+
print(f"======== Freezing DinoWrapper ========")
|
62 |
+
self.model.eval()
|
63 |
+
for name, param in self.model.named_parameters():
|
64 |
+
param.requires_grad = False
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5):
|
68 |
+
import requests
|
69 |
+
try:
|
70 |
+
model = ViTModel.from_pretrained(model_name, add_pooling_layer=False)
|
71 |
+
processor = ViTImageProcessor.from_pretrained(model_name)
|
72 |
+
return model, processor
|
73 |
+
except requests.exceptions.ProxyError as err:
|
74 |
+
if proxy_error_retries > 0:
|
75 |
+
print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...")
|
76 |
+
import time
|
77 |
+
time.sleep(proxy_error_cooldown)
|
78 |
+
return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown)
|
79 |
+
else:
|
80 |
+
raise err
|