Huiwenshi commited on
Commit
0514ca2
1 Parent(s): 0e1fde2

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +3 -4
  2. README.md +12 -0
  3. app_hg.py +234 -125
  4. third_party/dust3r/.gitignore +132 -0
  5. third_party/dust3r/.gitmodules +3 -0
  6. third_party/dust3r/LICENSE +7 -0
  7. third_party/dust3r/NOTICE +12 -0
  8. third_party/dust3r/README.md +390 -0
  9. third_party/dust3r/croco/LICENSE +52 -0
  10. third_party/dust3r/croco/NOTICE +21 -0
  11. third_party/dust3r/croco/README.MD +124 -0
  12. third_party/dust3r/croco/datasets/__init__.py +0 -0
  13. third_party/dust3r/croco/datasets/crops/README.MD +104 -0
  14. third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py +159 -0
  15. third_party/dust3r/croco/datasets/habitat_sim/README.MD +76 -0
  16. third_party/dust3r/croco/datasets/habitat_sim/__init__.py +0 -0
  17. third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py +92 -0
  18. third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py +27 -0
  19. third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py +177 -0
  20. third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py +390 -0
  21. third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py +69 -0
  22. third_party/dust3r/croco/datasets/habitat_sim/paths.py +129 -0
  23. third_party/dust3r/croco/datasets/pairs_dataset.py +109 -0
  24. third_party/dust3r/croco/datasets/transforms.py +95 -0
  25. third_party/dust3r/croco/demo.py +55 -0
  26. third_party/dust3r/croco/models/blocks.py +241 -0
  27. third_party/dust3r/croco/models/criterion.py +37 -0
  28. third_party/dust3r/croco/models/croco.py +249 -0
  29. third_party/dust3r/croco/models/croco_downstream.py +122 -0
  30. third_party/dust3r/croco/models/curope/__init__.py +4 -0
  31. third_party/dust3r/croco/models/curope/curope.cpp +69 -0
  32. third_party/dust3r/croco/models/curope/curope2d.py +40 -0
  33. third_party/dust3r/croco/models/curope/kernels.cu +108 -0
  34. third_party/dust3r/croco/models/curope/setup.py +34 -0
  35. third_party/dust3r/croco/models/dpt_block.py +450 -0
  36. third_party/dust3r/croco/models/head_downstream.py +58 -0
  37. third_party/dust3r/croco/models/masking.py +25 -0
  38. third_party/dust3r/croco/models/pos_embed.py +159 -0
  39. third_party/dust3r/croco/pretrain.py +254 -0
  40. third_party/dust3r/croco/stereoflow/README.MD +318 -0
  41. third_party/dust3r/croco/stereoflow/augmentor.py +290 -0
  42. third_party/dust3r/croco/stereoflow/criterion.py +251 -0
  43. third_party/dust3r/croco/stereoflow/datasets_flow.py +630 -0
  44. third_party/dust3r/croco/stereoflow/datasets_stereo.py +674 -0
  45. third_party/dust3r/croco/stereoflow/download_model.sh +12 -0
  46. third_party/dust3r/croco/stereoflow/engine.py +280 -0
  47. third_party/dust3r/croco/stereoflow/test.py +216 -0
  48. third_party/dust3r/croco/stereoflow/train.py +253 -0
  49. third_party/dust3r/croco/utils/misc.py +463 -0
  50. third_party/dust3r/datasets_preprocess/habitat/README.md +66 -0
.gitignore CHANGED
@@ -32,11 +32,10 @@
32
  ### PreCI ###
33
  .codecc
34
 
35
- app_hg.py
36
  outputs
37
  weights
38
  .vscode/
39
- baking
40
  inference.py
41
- third_party/weights
42
- third_party/dust3r
 
 
32
  ### PreCI ###
33
  .codecc
34
 
 
35
  outputs
36
  weights
37
  .vscode/
 
38
  inference.py
39
+ # third_party/weights
40
+ # third_party/dust3r
41
+ # app_hg.py
README.md CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  [English](README.md) | [简体中文](README_zh_cn.md)
2
 
3
  <!-- ## **Hunyuan3D-1.0** -->
 
1
+ ---
2
+ title: Hunyuan3D-1.0
3
+ emoji: 😻
4
+ colorFrom: purple
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.5.0
8
+ app_file: app_hg.py
9
+ pinned: false
10
+ short_description: Text-to-3D and Image-to-3D Generation
11
+ ---
12
+
13
  [English](README.md) | [简体中文](README_zh_cn.md)
14
 
15
  <!-- ## **Hunyuan3D-1.0** -->
app_hg.py CHANGED
@@ -23,7 +23,6 @@
23
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
  import spaces
25
  import os
26
- os.environ['CUDA_HOME'] = '/usr/local/cuda-11*'
27
  import warnings
28
  import argparse
29
  import gradio as gr
@@ -33,10 +32,22 @@ import torch
33
  import numpy as np
34
  from PIL import Image
35
  from einops import rearrange
 
36
  from huggingface_hub import snapshot_download
37
 
38
  from infer import seed_everything, save_gif
39
  from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  warnings.simplefilter('ignore', category=UserWarning)
42
  warnings.simplefilter('ignore', category=FutureWarning)
@@ -47,43 +58,14 @@ parser.add_argument("--use_lite", default=False, action="store_true")
47
  parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str)
48
  parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str)
49
  parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str)
50
- parser.add_argument("--save_memory", default=False) # , action="store_true")
51
  parser.add_argument("--device", default="cuda:0", type=str)
52
  args = parser.parse_args()
53
 
54
- @spaces.GPU
55
- def find_cuda():
56
- # Check if CUDA_HOME or CUDA_PATH environment variables are set
57
- cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH')
58
-
59
- if cuda_home and os.path.exists(cuda_home):
60
- return cuda_home
61
-
62
- # Search for the nvcc executable in the system's PATH
63
- nvcc_path = shutil.which('nvcc')
64
-
65
- if nvcc_path:
66
- # Remove the 'bin/nvcc' part to get the CUDA installation path
67
- cuda_path = os.path.dirname(os.path.dirname(nvcc_path))
68
- return cuda_path
69
-
70
- return None
71
-
72
- cuda_path = find_cuda()
73
-
74
- if cuda_path:
75
- print(f"CUDA installation found at: {cuda_path}")
76
- else:
77
- print("CUDA installation not found")
78
-
79
-
80
-
81
  def download_models():
82
- # Create weights directory if it doesn't exist
83
  os.makedirs("weights", exist_ok=True)
84
  os.makedirs("weights/hunyuanDiT", exist_ok=True)
85
-
86
- # Download Hunyuan3D-1 model
87
  try:
88
  snapshot_download(
89
  repo_id="tencent/Hunyuan3D-1",
@@ -93,8 +75,6 @@ def download_models():
93
  print("Successfully downloaded Hunyuan3D-1 model")
94
  except Exception as e:
95
  print(f"Error downloading Hunyuan3D-1: {e}")
96
-
97
- # Download HunyuanDiT model
98
  try:
99
  snapshot_download(
100
  repo_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
@@ -109,25 +89,27 @@ def download_models():
109
  download_models()
110
 
111
  ################################################################
 
 
112
 
113
- CONST_PORT = 8080
114
- CONST_MAX_QUEUE = 1
115
- CONST_SERVER = '0.0.0.0'
116
 
117
  CONST_HEADER = '''
118
  <h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'><b>Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
119
  ⭐️Technical report: <a href='https://arxiv.org/pdf/2411.02293' target='_blank'>ArXiv</a>. ⭐️Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>.
 
120
 
121
- ❗️❗️❗️**Important Notes**
 
122
 
123
- Our demo allows you to export models in various formats:
124
- - By default, export as a *.obj mesh with vertex colors or a *.glb mesh.
125
- - Select "texture mapping" to export a *.obj mesh with a texture map or a *.glb mesh.
126
- - Select "render GIF" to export a GIF rendering of the *.glb file.
127
 
128
- If the results aren't satisfactory, try using a different seed value (default is 0).
129
  '''
130
 
 
 
131
  ################################################################
132
 
133
  def get_example_img_list():
@@ -143,6 +125,9 @@ def get_example_txt_list():
143
 
144
  example_is = get_example_img_list()
145
  example_ts = get_example_txt_list()
 
 
 
146
  ################################################################
147
 
148
  worker_xbg = Removebg()
@@ -166,6 +151,12 @@ worker_v23 = Views2Mesh(
166
  )
167
  worker_gif = GifRenderer(args.device)
168
 
 
 
 
 
 
 
169
  @spaces.GPU
170
  def stage_0_t2i(text, image, seed, step):
171
  os.makedirs('./outputs/app_output', exist_ok=True)
@@ -177,7 +168,7 @@ def stage_0_t2i(text, image, seed, step):
177
  save_folder = f'./outputs/app_output/{cur_id}'
178
  os.makedirs(save_folder, exist_ok=True)
179
 
180
- dst = os.path.join(save_folder, 'img.png')
181
 
182
  if not text:
183
  if image is None:
@@ -190,16 +181,16 @@ def stage_0_t2i(text, image, seed, step):
190
  image.save(dst)
191
  dst = worker_xbg(image, save_folder)
192
  return dst, save_folder
193
-
194
  @spaces.GPU
195
- def stage_1_xbg(image, save_folder):
196
  if isinstance(image, str):
197
  image = Image.open(image)
198
  dst = save_folder + '/img_nobg.png'
199
- rgba = worker_xbg(image)
200
  rgba.save(dst)
201
  return dst
202
-
203
  @spaces.GPU
204
  def stage_2_i2v(image, seed, step, save_folder):
205
  if isinstance(image, str):
@@ -222,12 +213,9 @@ def stage_3_v23(
222
  seed,
223
  save_folder,
224
  target_face_count = 30000,
225
- do_texture_mapping = True,
226
- do_render =True
227
  ):
228
- do_texture_mapping = do_texture_mapping or do_render
229
- obj_dst = save_folder + '/mesh_with_colors.obj'
230
- glb_dst = save_folder + '/mesh.glb'
231
  worker_v23(
232
  views_pil,
233
  cond_pil,
@@ -236,95 +224,197 @@ def stage_3_v23(
236
  target_face_count = target_face_count,
237
  do_texture_mapping = do_texture_mapping
238
  )
 
 
 
239
  return obj_dst, glb_dst
240
 
241
  @spaces.GPU
242
- def stage_4_gif(obj_dst, save_folder, do_render_gif=True):
243
- if not do_render_gif: return None
244
- gif_dst = save_folder + '/output.gif'
245
- worker_gif(
246
- save_folder + '/mesh.obj',
247
- gif_dst_path = gif_dst
248
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  return gif_dst
250
 
251
- #===============================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  with gr.Blocks() as demo:
253
  gr.Markdown(CONST_HEADER)
254
  with gr.Row(variant="panel"):
 
 
 
255
  with gr.Column(scale=2):
 
 
 
256
  with gr.Tab("Text to 3D"):
257
  with gr.Column():
258
- text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。', lines=1, max_lines=10, label='Input text')
 
259
  with gr.Row():
260
- textgen_seed = gr.Number(value=0, label="T2I seed", precision=0)
261
- textgen_step = gr.Number(value=25, label="T2I steps", precision=0, minimum=10, maximum=50)
262
- textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
263
- textgen_STEP = gr.Number(value=50, label="Gen steps", precision=0, minimum=40, maximum=100)
264
- textgen_max_faces = gr.Number(value=90000, label="Face number", precision=0, minimum=5000, maximum=1000000)
265
-
 
 
 
 
 
 
 
 
 
 
 
 
 
266
  with gr.Row():
267
- # textgen_do_texture_mapping = gr.Checkbox(label="Texture mapping", value=False, interactive=True)
268
- # textgen_do_render_gif = gr.Checkbox(label="Render GIF", value=False, interactive=True)
269
  textgen_submit = gr.Button("Generate", variant="primary")
270
 
271
  with gr.Row():
272
  gr.Examples(examples=example_ts, inputs=[text], label="Text examples", examples_per_page=10)
273
 
 
 
274
  with gr.Tab("Image to 3D"):
275
- with gr.Column():
276
- input_image = gr.Image(label="Input image",
277
- width=256, height=256, type="pil",
278
- image_mode="RGBA", sources="upload",
279
- interactive=True)
280
- with gr.Row():
281
- imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0)
282
- imggen_STEP = gr.Number(value=50, label="Gen steps", precision=0, minimum=40, maximum=100)
283
- imggen_max_faces = gr.Number(value=90000, label="Face number", precision=0, minimum=5000, maximum=1000000)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
- with gr.Row():
286
- # imggen_do_texture_mapping = gr.Checkbox(label="Texture mapping", value=False, interactive=True)
287
- # imggen_do_render_gif = gr.Checkbox(label="Render GIF", value=False, interactive=True)
288
- imggen_submit = gr.Button("Generate", variant="primary")
289
- with gr.Row():
290
- gr.Examples(examples=example_is, inputs=[input_image], label="Img examples", examples_per_page=10)
291
-
292
  with gr.Column(scale=3):
293
  with gr.Row():
294
  with gr.Column(scale=2):
295
- rem_bg_image = gr.Image(label="Image without background", type="pil",
296
- image_mode="RGBA", interactive=False)
 
 
 
 
297
  with gr.Column(scale=3):
298
- result_image = gr.Image(label="Multi-view images", type="pil", interactive=False)
299
-
300
- with gr.Row():
 
 
 
 
301
  result_3dobj = gr.Model3D(
302
  clear_color=[0.0, 0.0, 0.0, 0.0],
303
- label="OBJ",
304
  show_label=True,
305
  visible=True,
306
  camera_position=[90, 90, None],
307
  interactive=False
308
  )
 
 
 
 
 
 
 
 
 
 
309
 
310
- result_3dglb = gr.Model3D(
311
  clear_color=[0.0, 0.0, 0.0, 0.0],
312
- label="GLB",
313
  show_label=True,
314
  visible=True,
315
  camera_position=[90, 90, None],
316
- interactive=False
317
- )
318
- # result_gif = gr.Image(label="Rendered GIF", interactive=False)
319
 
320
- with gr.Row():
321
- gr.Markdown("""Due to Gradio limitations, OBJ files are displayed with vertex shading only, while GLB files can be viewed with texture shading. For the best experience, we recommend downloading the GLB files and opening them with 3D software like Blender or MeshLab.""")
322
-
323
- #===============================================================
324
- textgen_do_texture_mapping = gr.State(False)
325
- textgen_do_render_gif = gr.State(False)
326
- imggen_do_texture_mapping = gr.State(False)
327
- imggen_do_render_gif = gr.State(False)
 
 
 
328
 
329
  none = gr.State(None)
330
  save_folder = gr.State()
@@ -332,41 +422,60 @@ with gr.Blocks() as demo:
332
  views_image = gr.State()
333
  text_image = gr.State()
334
 
 
335
  textgen_submit.click(
336
- fn=stage_0_t2i, inputs=[text, none, textgen_seed, textgen_step],
 
337
  outputs=[rem_bg_image, save_folder],
338
  ).success(
339
- fn=stage_2_i2v, inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
 
340
  outputs=[views_image, cond_image, result_image],
341
  ).success(
342
- fn=stage_3_v23, inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, textgen_do_texture_mapping, textgen_do_render_gif],
343
- outputs=[result_3dobj, result_3dglb],
 
 
 
 
 
 
 
 
 
344
  ).success(lambda: print('Text_to_3D Done ...'))
345
- # .success(
346
- # fn=stage_4_gif, inputs=[result_3dglb, save_folder, textgen_do_render_gif],
347
- # outputs=[result_gif],
348
- # ).success(lambda: print('Text_to_3D Done ...'))
349
 
 
350
  imggen_submit.click(
351
- fn=stage_0_t2i, inputs=[none, input_image, textgen_seed, textgen_step],
 
352
  outputs=[text_image, save_folder],
353
  ).success(
354
- fn=stage_1_xbg, inputs=[text_image, save_folder],
 
355
  outputs=[rem_bg_image],
356
  ).success(
357
- fn=stage_2_i2v, inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
 
358
  outputs=[views_image, cond_image, result_image],
359
  ).success(
360
- fn=stage_3_v23, inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, imggen_do_texture_mapping, imggen_do_render_gif],
361
- outputs=[result_3dobj, result_3dglb],
 
 
 
 
 
 
 
 
 
362
  ).success(lambda: print('Image_to_3D Done ...'))
363
- # success(
364
- # fn=stage_4_gif, inputs=[result_3dglb, save_folder, imggen_do_render_gif],
365
- # outputs=[result_gif],
366
- # ).success(lambda: print('Image_to_3D Done ...'))
367
 
368
- #===============================================================
 
 
369
 
370
- demo.queue()
371
- demo.launch()
372
 
 
23
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24
  import spaces
25
  import os
 
26
  import warnings
27
  import argparse
28
  import gradio as gr
 
32
  import numpy as np
33
  from PIL import Image
34
  from einops import rearrange
35
+ import pandas as pd
36
  from huggingface_hub import snapshot_download
37
 
38
  from infer import seed_everything, save_gif
39
  from infer import Text2Image, Removebg, Image2Views, Views2Mesh, GifRenderer
40
+ from third_party.check import check_bake_available
41
+
42
+ try:
43
+ from third_party.mesh_baker import MeshBaker
44
+ BAKE_AVAILEBLE = True
45
+ except Exception as err:
46
+ print(err)
47
+ print("import baking related fail, run without baking")
48
+ check_bake_available()
49
+ BAKE_AVAILEBLE = False
50
+
51
 
52
  warnings.simplefilter('ignore', category=UserWarning)
53
  warnings.simplefilter('ignore', category=FutureWarning)
 
58
  parser.add_argument("--mv23d_cfg_path", default="./svrm/configs/svrm.yaml", type=str)
59
  parser.add_argument("--mv23d_ckt_path", default="weights/svrm/svrm.safetensors", type=str)
60
  parser.add_argument("--text2image_path", default="weights/hunyuanDiT", type=str)
61
+ parser.add_argument("--save_memory", default=False, action="store_true")
62
  parser.add_argument("--device", default="cuda:0", type=str)
63
  args = parser.parse_args()
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def download_models():
 
66
  os.makedirs("weights", exist_ok=True)
67
  os.makedirs("weights/hunyuanDiT", exist_ok=True)
68
+ os.makedirs("third_party/weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt", exist_ok=True)
 
69
  try:
70
  snapshot_download(
71
  repo_id="tencent/Hunyuan3D-1",
 
75
  print("Successfully downloaded Hunyuan3D-1 model")
76
  except Exception as e:
77
  print(f"Error downloading Hunyuan3D-1: {e}")
 
 
78
  try:
79
  snapshot_download(
80
  repo_id="Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers-Distilled",
 
89
  download_models()
90
 
91
  ################################################################
92
+ # initial setting
93
+ ################################################################
94
 
 
 
 
95
 
96
  CONST_HEADER = '''
97
  <h2><a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'><b>Tencent Hunyuan3D-1.0: A Unified Framework for Text-to-3D and Image-to-3D Generation</b></a></h2>
98
  ⭐️Technical report: <a href='https://arxiv.org/pdf/2411.02293' target='_blank'>ArXiv</a>. ⭐️Code: <a href='https://github.com/tencent/Hunyuan3D-1' target='_blank'>GitHub</a>.
99
+ '''
100
 
101
+ CONST_NOTE = '''
102
+ ❗️❗️❗️Usage❗️❗️❗️<br>
103
 
104
+ Limited by format, the model can only export *.obj mesh with vertex colors. The "texture" mod can only work on *.glb.<br>
105
+ Please click "Do Rendering" to export a GIF.<br>
106
+ You can click "Do Baking" to bake multi-view imgaes onto the shape.<br>
 
107
 
108
+ If the results aren't satisfactory, please try a different radnom seed (default is 0).
109
  '''
110
 
111
+ ################################################################
112
+ # prepare text examples and image examples
113
  ################################################################
114
 
115
  def get_example_img_list():
 
125
 
126
  example_is = get_example_img_list()
127
  example_ts = get_example_txt_list()
128
+
129
+ ################################################################
130
+ # initial models
131
  ################################################################
132
 
133
  worker_xbg = Removebg()
 
151
  )
152
  worker_gif = GifRenderer(args.device)
153
 
154
+
155
+ if BAKE_AVAILEBLE:
156
+ worker_baker = MeshBaker()
157
+
158
+
159
+ ### functional modules
160
  @spaces.GPU
161
  def stage_0_t2i(text, image, seed, step):
162
  os.makedirs('./outputs/app_output', exist_ok=True)
 
168
  save_folder = f'./outputs/app_output/{cur_id}'
169
  os.makedirs(save_folder, exist_ok=True)
170
 
171
+ dst = save_folder + '/img.png'
172
 
173
  if not text:
174
  if image is None:
 
181
  image.save(dst)
182
  dst = worker_xbg(image, save_folder)
183
  return dst, save_folder
184
+
185
  @spaces.GPU
186
+ def stage_1_xbg(image, save_folder, force_remove):
187
  if isinstance(image, str):
188
  image = Image.open(image)
189
  dst = save_folder + '/img_nobg.png'
190
+ rgba = worker_xbg(image, force=force_remove)
191
  rgba.save(dst)
192
  return dst
193
+
194
  @spaces.GPU
195
  def stage_2_i2v(image, seed, step, save_folder):
196
  if isinstance(image, str):
 
213
  seed,
214
  save_folder,
215
  target_face_count = 30000,
216
+ texture_color = 'texture'
 
217
  ):
218
+ do_texture_mapping = texture_color == 'texture'
 
 
219
  worker_v23(
220
  views_pil,
221
  cond_pil,
 
224
  target_face_count = target_face_count,
225
  do_texture_mapping = do_texture_mapping
226
  )
227
+ glb_dst = save_folder + '/mesh.glb' if do_texture_mapping else None
228
+ obj_dst = save_folder + '/mesh.obj'
229
+ obj_dst = save_folder + '/mesh_vertex_colors.obj' # gradio just only can show vertex shading
230
  return obj_dst, glb_dst
231
 
232
  @spaces.GPU
233
+ def stage_3p_baking(save_folder, color, bake):
234
+ if color == "texture" and bake:
235
+ obj_dst = worker_baker(save_folder)
236
+ glb_dst = obj_dst.replace(".obj", ".glb")
237
+ return glb_dst
238
+ else:
239
+ return None
240
+
241
+ @spaces.GPU
242
+ def stage_4_gif(save_folder, color, bake, render):
243
+ if not render: return None
244
+ if os.path.exists(save_folder + '/view_1/bake/mesh.obj'):
245
+ obj_dst = save_folder + '/view_1/bake/mesh.obj'
246
+ elif os.path.exists(save_folder + '/view_0/bake/mesh.obj'):
247
+ obj_dst = save_folder + '/view_0/bake/mesh.obj'
248
+ elif os.path.exists(save_folder + '/mesh.obj'):
249
+ obj_dst = save_folder + '/mesh.obj'
250
+ else:
251
+ print(save_folder)
252
+ raise FileNotFoundError("mesh obj file not found")
253
+ gif_dst = obj_dst.replace(".obj", ".gif")
254
+ worker_gif(obj_dst, gif_dst_path=gif_dst)
255
  return gif_dst
256
 
257
+
258
+ def check_image_available(image):
259
+ if image.mode == "RGBA":
260
+ data = np.array(image)
261
+ alpha_channel = data[:, :, 3]
262
+ unique_alpha_values = np.unique(alpha_channel)
263
+ if len(unique_alpha_values) == 1:
264
+ msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
265
+ return msg, gr.update(value=True, interactive=False)
266
+ else:
267
+ msg = "The image has four channels, and you can choose to remove the background or not."
268
+ return msg, gr.update(value=False, interactive=True)
269
+ elif image.mode == "RGB":
270
+ msg = "The alpha channel is missing or invalid. The background removal option is selected for you."
271
+ return msg, gr.update(value=True, interactive=False)
272
+ else:
273
+ raise Exception("Image Error")
274
+
275
+ def update_bake_render(color):
276
+ if color == "vertex":
277
+ return gr.update(value=False, interactive=False), gr.update(value=False, interactive=False)
278
+ else:
279
+ return gr.update(interactive=True), gr.update(interactive=True)
280
+
281
+ # ===============================================================
282
+ # gradio display
283
+ # ===============================================================
284
+
285
  with gr.Blocks() as demo:
286
  gr.Markdown(CONST_HEADER)
287
  with gr.Row(variant="panel"):
288
+
289
+ ###### Input region
290
+
291
  with gr.Column(scale=2):
292
+
293
+ ### Text iutput region
294
+
295
  with gr.Tab("Text to 3D"):
296
  with gr.Column():
297
+ text = gr.TextArea('一只黑白相间的熊猫在白色背景上居中坐着,呈现出卡通风格和可爱氛围。',
298
+ lines=3, max_lines=20, label='Input text')
299
  with gr.Row():
300
+ textgen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
301
+ with gr.Row():
302
+ textgen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
303
+ if BAKE_AVAILEBLE:
304
+ textgen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
305
+ else:
306
+ textgen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
307
+
308
+ textgen_color.change(fn=update_bake_render, inputs=textgen_color, outputs=[textgen_bake, textgen_render])
309
+
310
+ with gr.Row():
311
+ textgen_seed = gr.Number(value=0, label="T2I seed", precision=0, interactive=True)
312
+ textgen_step = gr.Number(value=25, label="T2I steps", precision=0,
313
+ minimum=10, maximum=50, interactive=True)
314
+ textgen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
315
+ textgen_STEP = gr.Number(value=50, label="Gen steps", precision=0,
316
+ minimum=40, maximum=100, interactive=True)
317
+ textgen_max_faces = gr.Number(value=90000, label="Face number", precision=0,
318
+ minimum=5000, maximum=1000000, interactive=True)
319
  with gr.Row():
 
 
320
  textgen_submit = gr.Button("Generate", variant="primary")
321
 
322
  with gr.Row():
323
  gr.Examples(examples=example_ts, inputs=[text], label="Text examples", examples_per_page=10)
324
 
325
+ ### Image iutput region
326
+
327
  with gr.Tab("Image to 3D"):
328
+ with gr.Row():
329
+ input_image = gr.Image(label="Input image", width=256, height=256, type="pil",
330
+ image_mode="RGBA", sources="upload", interactive=True)
331
+ with gr.Row():
332
+ alert_message = gr.Markdown("") # for warning
333
+ with gr.Row():
334
+ imggen_color = gr.Radio(choices=["vertex", "texture"], label="Color", value="texture")
335
+ with gr.Row():
336
+ imggen_removebg = gr.Checkbox(label="Remove Background", value=True, interactive=True)
337
+ imggen_render = gr.Checkbox(label="Do Rendering", value=True, interactive=True)
338
+ if BAKE_AVAILEBLE:
339
+ imggen_bake = gr.Checkbox(label="Do Baking", value=True, interactive=True)
340
+ else:
341
+ imggen_bake = gr.Checkbox(label="Do Baking", value=False, interactive=False)
342
+
343
+ input_image.change(fn=check_image_available, inputs=input_image, outputs=[alert_message, imggen_removebg])
344
+ imggen_color.change(fn=update_bake_render, inputs=imggen_color, outputs=[imggen_bake, imggen_render])
345
+
346
+ with gr.Row():
347
+ imggen_SEED = gr.Number(value=0, label="Gen seed", precision=0, interactive=True)
348
+ imggen_STEP = gr.Number(value=50, label="Gen steps", precision=0,
349
+ minimum=40, maximum=100, interactive=True)
350
+ imggen_max_faces = gr.Number(value=90000, label="Face number", precision=0,
351
+ minimum=5000, maximum=1000000, interactive=True)
352
+ with gr.Row():
353
+ imggen_submit = gr.Button("Generate", variant="primary")
354
+
355
+ with gr.Row():
356
+ gr.Examples(examples=example_is, inputs=[input_image],
357
+ label="Img examples", examples_per_page=10)
358
+
359
+ gr.Markdown(CONST_NOTE)
360
+
361
+ ###### Output region
362
 
 
 
 
 
 
 
 
363
  with gr.Column(scale=3):
364
  with gr.Row():
365
  with gr.Column(scale=2):
366
+ rem_bg_image = gr.Image(
367
+ label="Image without background",
368
+ type="pil",
369
+ image_mode="RGBA",
370
+ interactive=False
371
+ )
372
  with gr.Column(scale=3):
373
+ result_image = gr.Image(
374
+ label="Multi-view images",
375
+ type="pil",
376
+ interactive=False
377
+ )
378
+
379
+ with gr.Row():
380
  result_3dobj = gr.Model3D(
381
  clear_color=[0.0, 0.0, 0.0, 0.0],
382
+ label="OBJ vertex color",
383
  show_label=True,
384
  visible=True,
385
  camera_position=[90, 90, None],
386
  interactive=False
387
  )
388
+ result_gif = gr.Image(label="GIF", interactive=False)
389
+
390
+ with gr.Row():
391
+ result_3dglb_texture = gr.Model3D(
392
+ clear_color=[0.0, 0.0, 0.0, 0.0],
393
+ label="GLB texture color",
394
+ show_label=True,
395
+ visible=True,
396
+ camera_position=[90, 90, None],
397
+ interactive=False)
398
 
399
+ result_3dglb_baked = gr.Model3D(
400
  clear_color=[0.0, 0.0, 0.0, 0.0],
401
+ label="GLB baked color",
402
  show_label=True,
403
  visible=True,
404
  camera_position=[90, 90, None],
405
+ interactive=False)
 
 
406
 
407
+ with gr.Row():
408
+ gr.Markdown(
409
+ "Due to Gradio limitations, OBJ files are displayed with vertex shading only, "
410
+ "while GLB files can be viewed with texture shading. <br>For the best experience, "
411
+ "we recommend downloading the GLB files and opening them with 3D software "
412
+ "like Blender or MeshLab."
413
+ )
414
+
415
+ #===============================================================
416
+ # gradio running code
417
+ #===============================================================
418
 
419
  none = gr.State(None)
420
  save_folder = gr.State()
 
422
  views_image = gr.State()
423
  text_image = gr.State()
424
 
425
+
426
  textgen_submit.click(
427
+ fn=stage_0_t2i,
428
+ inputs=[text, none, textgen_seed, textgen_step],
429
  outputs=[rem_bg_image, save_folder],
430
  ).success(
431
+ fn=stage_2_i2v,
432
+ inputs=[rem_bg_image, textgen_SEED, textgen_STEP, save_folder],
433
  outputs=[views_image, cond_image, result_image],
434
  ).success(
435
+ fn=stage_3_v23,
436
+ inputs=[views_image, cond_image, textgen_SEED, save_folder, textgen_max_faces, textgen_color],
437
+ outputs=[result_3dobj, result_3dglb_texture],
438
+ ).success(
439
+ fn=stage_3p_baking,
440
+ inputs=[save_folder, textgen_color, textgen_bake],
441
+ outputs=[result_3dglb_baked],
442
+ ).success(
443
+ fn=stage_4_gif,
444
+ inputs=[save_folder, textgen_color, textgen_bake, textgen_render],
445
+ outputs=[result_gif],
446
  ).success(lambda: print('Text_to_3D Done ...'))
 
 
 
 
447
 
448
+
449
  imggen_submit.click(
450
+ fn=stage_0_t2i,
451
+ inputs=[none, input_image, textgen_seed, textgen_step],
452
  outputs=[text_image, save_folder],
453
  ).success(
454
+ fn=stage_1_xbg,
455
+ inputs=[text_image, save_folder, imggen_removebg],
456
  outputs=[rem_bg_image],
457
  ).success(
458
+ fn=stage_2_i2v,
459
+ inputs=[rem_bg_image, imggen_SEED, imggen_STEP, save_folder],
460
  outputs=[views_image, cond_image, result_image],
461
  ).success(
462
+ fn=stage_3_v23,
463
+ inputs=[views_image, cond_image, imggen_SEED, save_folder, imggen_max_faces, imggen_color],
464
+ outputs=[result_3dobj, result_3dglb_texture],
465
+ ).success(
466
+ fn=stage_3p_baking,
467
+ inputs=[save_folder, imggen_color, imggen_bake],
468
+ outputs=[result_3dglb_baked],
469
+ ).success(
470
+ fn=stage_4_gif,
471
+ inputs=[save_folder, imggen_color, imggen_bake, imggen_render],
472
+ outputs=[result_gif],
473
  ).success(lambda: print('Image_to_3D Done ...'))
 
 
 
 
474
 
475
+ #===============================================================
476
+ # start gradio server
477
+ #===============================================================
478
 
479
+ demo.queue(max_size=CONST_MAX_QUEUE)
480
+ demo.launch(server_name=CONST_SERVER, server_port=CONST_PORT)
481
 
third_party/dust3r/.gitignore ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data/
2
+ checkpoints/
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ pip-wheel-metadata/
27
+ share/python-wheels/
28
+ *.egg-info/
29
+ .installed.cfg
30
+ *.egg
31
+ MANIFEST
32
+
33
+ # PyInstaller
34
+ # Usually these files are written by a python script from a template
35
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
36
+ *.manifest
37
+ *.spec
38
+
39
+ # Installer logs
40
+ pip-log.txt
41
+ pip-delete-this-directory.txt
42
+
43
+ # Unit test / coverage reports
44
+ htmlcov/
45
+ .tox/
46
+ .nox/
47
+ .coverage
48
+ .coverage.*
49
+ .cache
50
+ nosetests.xml
51
+ coverage.xml
52
+ *.cover
53
+ *.py,cover
54
+ .hypothesis/
55
+ .pytest_cache/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ target/
79
+
80
+ # Jupyter Notebook
81
+ .ipynb_checkpoints
82
+
83
+ # IPython
84
+ profile_default/
85
+ ipython_config.py
86
+
87
+ # pyenv
88
+ .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
third_party/dust3r/.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "croco"]
2
+ path = croco
3
+ url = https://github.com/naver/croco
third_party/dust3r/LICENSE ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ DUSt3R, Copyright (c) 2024-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
third_party/dust3r/NOTICE ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ DUSt3R
2
+ Copyright 2024-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ naver/croco
10
+ https://github.com/naver/croco/
11
+
12
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0
third_party/dust3r/README.md ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ![demo](assets/dust3r.jpg)
2
+
3
+ Official implementation of `DUSt3R: Geometric 3D Vision Made Easy`
4
+ [[Project page](https://dust3r.europe.naverlabs.com/)], [[DUSt3R arxiv](https://arxiv.org/abs/2312.14132)]
5
+
6
+ > **Make sure to also check [MASt3R](https://github.com/naver/mast3r): Our new model with a local feature head, metric pointmaps, and a more scalable global alignment!**
7
+
8
+ ![Example of reconstruction from two images](assets/pipeline1.jpg)
9
+
10
+ ![High level overview of DUSt3R capabilities](assets/dust3r_archi.jpg)
11
+
12
+ ```bibtex
13
+ @inproceedings{dust3r_cvpr24,
14
+ title={DUSt3R: Geometric 3D Vision Made Easy},
15
+ author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
16
+ booktitle = {CVPR},
17
+ year = {2024}
18
+ }
19
+
20
+ @misc{dust3r_arxiv23,
21
+ title={DUSt3R: Geometric 3D Vision Made Easy},
22
+ author={Shuzhe Wang and Vincent Leroy and Yohann Cabon and Boris Chidlovskii and Jerome Revaud},
23
+ year={2023},
24
+ eprint={2312.14132},
25
+ archivePrefix={arXiv},
26
+ primaryClass={cs.CV}
27
+ }
28
+ ```
29
+
30
+ ## Table of Contents
31
+
32
+ - [Table of Contents](#table-of-contents)
33
+ - [License](#license)
34
+ - [Get Started](#get-started)
35
+ - [Installation](#installation)
36
+ - [Checkpoints](#checkpoints)
37
+ - [Interactive demo](#interactive-demo)
38
+ - [Interactive demo with docker](#interactive-demo-with-docker)
39
+ - [Usage](#usage)
40
+ - [Training](#training)
41
+ - [Datasets](#datasets)
42
+ - [Demo](#demo)
43
+ - [Our Hyperparameters](#our-hyperparameters)
44
+
45
+ ## License
46
+
47
+ The code is distributed under the CC BY-NC-SA 4.0 License.
48
+ See [LICENSE](LICENSE) for more information.
49
+
50
+ ```python
51
+ # Copyright (C) 2024-present Naver Corporation. All rights reserved.
52
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
53
+ ```
54
+
55
+ ## Get Started
56
+
57
+ ### Installation
58
+
59
+ 1. Clone DUSt3R.
60
+ ```bash
61
+ git clone --recursive https://github.com/naver/dust3r
62
+ cd dust3r
63
+ # if you have already cloned dust3r:
64
+ # git submodule update --init --recursive
65
+ ```
66
+
67
+ 2. Create the environment, here we show an example using conda.
68
+ ```bash
69
+ conda create -n dust3r python=3.11 cmake=3.14.0
70
+ conda activate dust3r
71
+ conda install pytorch torchvision pytorch-cuda=12.1 -c pytorch -c nvidia # use the correct version of cuda for your system
72
+ pip install -r requirements.txt
73
+ # Optional: you can also install additional packages to:
74
+ # - add support for HEIC images
75
+ # - add pyrender, used to render depthmap in some datasets preprocessing
76
+ # - add required packages for visloc.py
77
+ pip install -r requirements_optional.txt
78
+ ```
79
+
80
+ 3. Optional, compile the cuda kernels for RoPE (as in CroCo v2).
81
+ ```bash
82
+ # DUST3R relies on RoPE positional embeddings for which you can compile some cuda kernels for faster runtime.
83
+ cd croco/models/curope/
84
+ python setup.py build_ext --inplace
85
+ cd ../../../
86
+ ```
87
+
88
+ ### Checkpoints
89
+
90
+ You can obtain the checkpoints by two ways:
91
+
92
+ 1) You can use our huggingface_hub integration: the models will be downloaded automatically.
93
+
94
+ 2) Otherwise, We provide several pre-trained models:
95
+
96
+ | Modelname | Training resolutions | Head | Encoder | Decoder |
97
+ |-------------|----------------------|------|---------|---------|
98
+ | [`DUSt3R_ViTLarge_BaseDecoder_224_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_224_linear.pth) | 224x224 | Linear | ViT-L | ViT-B |
99
+ | [`DUSt3R_ViTLarge_BaseDecoder_512_linear.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_linear.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | Linear | ViT-L | ViT-B |
100
+ | [`DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth) | 512x384, 512x336, 512x288, 512x256, 512x160 | DPT | ViT-L | ViT-B |
101
+
102
+ You can check the hyperparameters we used to train these models in the [section: Our Hyperparameters](#our-hyperparameters)
103
+
104
+ To download a specific model, for example `DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth`:
105
+ ```bash
106
+ mkdir -p checkpoints/
107
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth -P checkpoints/
108
+ ```
109
+
110
+ For the checkpoints, make sure to agree to the license of all the public training datasets and base checkpoints we used, in addition to CC-BY-NC-SA 4.0. Again, see [section: Our Hyperparameters](#our-hyperparameters) for details.
111
+
112
+ ### Interactive demo
113
+
114
+ In this demo, you should be able run DUSt3R on your machine to reconstruct a scene.
115
+ First select images that depicts the same scene.
116
+
117
+ You can adjust the global alignment schedule and its number of iterations.
118
+
119
+ > [!NOTE]
120
+ > If you selected one or two images, the global alignment procedure will be skipped (mode=GlobalAlignerMode.PairViewer)
121
+
122
+ Hit "Run" and wait.
123
+ When the global alignment ends, the reconstruction appears.
124
+ Use the slider "min_conf_thr" to show or remove low confidence areas.
125
+
126
+ ```bash
127
+ python3 demo.py --model_name DUSt3R_ViTLarge_BaseDecoder_512_dpt
128
+
129
+ # Use --weights to load a checkpoint from a local file, eg --weights checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
130
+ # Use --image_size to select the correct resolution for the selected checkpoint. 512 (default) or 224
131
+ # Use --local_network to make it accessible on the local network, or --server_name to specify the url manually
132
+ # Use --server_port to change the port, by default it will search for an available port starting at 7860
133
+ # Use --device to use a different device, by default it's "cuda"
134
+ ```
135
+
136
+ ### Interactive demo with docker
137
+
138
+ To run DUSt3R using Docker, including with NVIDIA CUDA support, follow these instructions:
139
+
140
+ 1. **Install Docker**: If not already installed, download and install `docker` and `docker compose` from the [Docker website](https://www.docker.com/get-started).
141
+
142
+ 2. **Install NVIDIA Docker Toolkit**: For GPU support, install the NVIDIA Docker toolkit from the [Nvidia website](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html).
143
+
144
+ 3. **Build the Docker image and run it**: `cd` into the `./docker` directory and run the following commands:
145
+
146
+ ```bash
147
+ cd docker
148
+ bash run.sh --with-cuda --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
149
+ ```
150
+
151
+ Or if you want to run the demo without CUDA support, run the following command:
152
+
153
+ ```bash
154
+ cd docker
155
+ bash run.sh --model_name="DUSt3R_ViTLarge_BaseDecoder_512_dpt"
156
+ ```
157
+
158
+ By default, `demo.py` is lanched with the option `--local_network`.
159
+ Visit `http://localhost:7860/` to access the web UI (or replace `localhost` with the machine's name to access it from the network).
160
+
161
+ `run.sh` will launch docker-compose using either the [docker-compose-cuda.yml](docker/docker-compose-cuda.yml) or [docker-compose-cpu.ym](docker/docker-compose-cpu.yml) config file, then it starts the demo using [entrypoint.sh](docker/files/entrypoint.sh).
162
+
163
+
164
+ ![demo](assets/demo.jpg)
165
+
166
+ ## Usage
167
+
168
+ ```python
169
+ from dust3r.inference import inference
170
+ from dust3r.model import AsymmetricCroCo3DStereo
171
+ from dust3r.utils.image import load_images
172
+ from dust3r.image_pairs import make_pairs
173
+ from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
174
+
175
+ if __name__ == '__main__':
176
+ device = 'cuda'
177
+ batch_size = 1
178
+ schedule = 'cosine'
179
+ lr = 0.01
180
+ niter = 300
181
+
182
+ model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
183
+ # you can put the path to a local checkpoint in model_name if needed
184
+ model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)
185
+ # load_images can take a list of images or a directory
186
+ images = load_images(['croco/assets/Chateau1.png', 'croco/assets/Chateau2.png'], size=512)
187
+ pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
188
+ output = inference(pairs, model, device, batch_size=batch_size)
189
+
190
+ # at this stage, you have the raw dust3r predictions
191
+ view1, pred1 = output['view1'], output['pred1']
192
+ view2, pred2 = output['view2'], output['pred2']
193
+ # here, view1, pred1, view2, pred2 are dicts of lists of len(2)
194
+ # -> because we symmetrize we have (im1, im2) and (im2, im1) pairs
195
+ # in each view you have:
196
+ # an integer image identifier: view1['idx'] and view2['idx']
197
+ # the img: view1['img'] and view2['img']
198
+ # the image shape: view1['true_shape'] and view2['true_shape']
199
+ # an instance string output by the dataloader: view1['instance'] and view2['instance']
200
+ # pred1 and pred2 contains the confidence values: pred1['conf'] and pred2['conf']
201
+ # pred1 contains 3D points for view1['img'] in view1['img'] space: pred1['pts3d']
202
+ # pred2 contains 3D points for view2['img'] in view1['img'] space: pred2['pts3d_in_other_view']
203
+
204
+ # next we'll use the global_aligner to align the predictions
205
+ # depending on your task, you may be fine with the raw output and not need it
206
+ # with only two input images, you could use GlobalAlignerMode.PairViewer: it would just convert the output
207
+ # if using GlobalAlignerMode.PairViewer, no need to run compute_global_alignment
208
+ scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
209
+ loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
210
+
211
+ # retrieve useful values from scene:
212
+ imgs = scene.imgs
213
+ focals = scene.get_focals()
214
+ poses = scene.get_im_poses()
215
+ pts3d = scene.get_pts3d()
216
+ confidence_masks = scene.get_masks()
217
+
218
+ # visualize reconstruction
219
+ scene.show()
220
+
221
+ # find 2D-2D matches between the two images
222
+ from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
223
+ pts2d_list, pts3d_list = [], []
224
+ for i in range(2):
225
+ conf_i = confidence_masks[i].cpu().numpy()
226
+ pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i]) # imgs[i].shape[:2] = (H, W)
227
+ pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
228
+ reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
229
+ print(f'found {num_matches} matches')
230
+ matches_im1 = pts2d_list[1][reciprocal_in_P2]
231
+ matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
232
+
233
+ # visualize a few matches
234
+ import numpy as np
235
+ from matplotlib import pyplot as pl
236
+ n_viz = 10
237
+ match_idx_to_viz = np.round(np.linspace(0, num_matches-1, n_viz)).astype(int)
238
+ viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
239
+
240
+ H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
241
+ img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
242
+ img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
243
+ img = np.concatenate((img0, img1), axis=1)
244
+ pl.figure()
245
+ pl.imshow(img)
246
+ cmap = pl.get_cmap('jet')
247
+ for i in range(n_viz):
248
+ (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
249
+ pl.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
250
+ pl.show(block=True)
251
+
252
+ ```
253
+ ![matching example on croco pair](assets/matching.jpg)
254
+
255
+ ## Training
256
+
257
+ In this section, we present a short demonstration to get started with training DUSt3R.
258
+
259
+ ### Datasets
260
+ At this moment, we have added the following training datasets:
261
+ - [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE)
262
+ - [ARKitScenes](https://github.com/apple/ARKitScenes) - [Creative Commons Attribution-NonCommercial-ShareAlike 4.0](https://github.com/apple/ARKitScenes/tree/main?tab=readme-ov-file#license)
263
+ - [ScanNet++](https://kaldir.vc.in.tum.de/scannetpp/) - [non-commercial research and educational purposes](https://kaldir.vc.in.tum.de/scannetpp/static/scannetpp-terms-of-use.pdf)
264
+ - [BlendedMVS](https://github.com/YoYo000/BlendedMVS) - [Creative Commons Attribution 4.0 International License](https://creativecommons.org/licenses/by/4.0/)
265
+ - [WayMo Open dataset](https://github.com/waymo-research/waymo-open-dataset) - [Non-Commercial Use](https://waymo.com/open/terms/)
266
+ - [Habitat-Sim](https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md)
267
+ - [MegaDepth](https://www.cs.cornell.edu/projects/megadepth/)
268
+ - [StaticThings3D](https://github.com/lmb-freiburg/robustmvd/blob/master/rmvd/data/README.md#staticthings3d)
269
+ - [WildRGB-D](https://github.com/wildrgbd/wildrgbd/)
270
+
271
+ For each dataset, we provide a preprocessing script in the `datasets_preprocess` directory and an archive containing the list of pairs when needed.
272
+ You have to download the datasets yourself from their official sources, agree to their license, download our list of pairs, and run the preprocessing script.
273
+
274
+ Links:
275
+
276
+ [ARKitScenes pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/arkitscenes_pairs.zip)
277
+ [ScanNet++ pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/scannetpp_pairs.zip)
278
+ [BlendedMVS pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/blendedmvs_pairs.npy)
279
+ [WayMo Open dataset pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/waymo_pairs.npz)
280
+ [Habitat metadata](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz)
281
+ [MegaDepth pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/megadepth_pairs.npz)
282
+ [StaticThings3D pairs](https://download.europe.naverlabs.com/ComputerVision/DUSt3R/staticthings_pairs.npy)
283
+
284
+ > [!NOTE]
285
+ > They are not strictly equivalent to what was used to train DUSt3R, but they should be close enough.
286
+
287
+ ### Demo
288
+ For this training demo, we're going to download and prepare a subset of [CO3Dv2](https://github.com/facebookresearch/co3d) - [Creative Commons Attribution-NonCommercial 4.0 International](https://github.com/facebookresearch/co3d/blob/main/LICENSE) and launch the training code on it.
289
+ The demo model will be trained for a few epochs on a very small dataset.
290
+ It will not be very good.
291
+
292
+ ```bash
293
+ # download and prepare the co3d subset
294
+ mkdir -p data/co3d_subset
295
+ cd data/co3d_subset
296
+ git clone https://github.com/facebookresearch/co3d
297
+ cd co3d
298
+ python3 ./co3d/download_dataset.py --download_folder ../ --single_sequence_subset
299
+ rm ../*.zip
300
+ cd ../../..
301
+
302
+ python3 datasets_preprocess/preprocess_co3d.py --co3d_dir data/co3d_subset --output_dir data/co3d_subset_processed --single_sequence_subset
303
+
304
+ # download the pretrained croco v2 checkpoint
305
+ mkdir -p checkpoints/
306
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth -P checkpoints/
307
+
308
+ # the training of dust3r is done in 3 steps.
309
+ # for this example we'll do fewer epochs, for the actual hyperparameters we used in the paper, see the next section: "Our Hyperparameters"
310
+ # step 1 - train dust3r for 224 resolution
311
+ torchrun --nproc_per_node=4 train.py \
312
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter)" \
313
+ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=224, seed=777)" \
314
+ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
315
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
316
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
317
+ --pretrained "checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
318
+ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 16 --accum_iter 1 \
319
+ --save_freq 1 --keep_freq 5 --eval_freq 1 \
320
+ --output_dir "checkpoints/dust3r_demo_224"
321
+
322
+ # step 2 - train dust3r for 512 resolution
323
+ torchrun --nproc_per_node=4 train.py \
324
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
325
+ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
326
+ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
327
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
328
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
329
+ --pretrained "checkpoints/dust3r_demo_224/checkpoint-best.pth" \
330
+ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 4 --accum_iter 4 \
331
+ --save_freq 1 --keep_freq 5 --eval_freq 1 \
332
+ --output_dir "checkpoints/dust3r_demo_512"
333
+
334
+ # step 3 - train dust3r for 512 resolution with dpt
335
+ torchrun --nproc_per_node=4 train.py \
336
+ --train_dataset "1000 @ Co3d(split='train', ROOT='data/co3d_subset_processed', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter)" \
337
+ --test_dataset "100 @ Co3d(split='test', ROOT='data/co3d_subset_processed', resolution=(512,384), seed=777)" \
338
+ --model "AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
339
+ --train_criterion "ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
340
+ --test_criterion "Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
341
+ --pretrained "checkpoints/dust3r_demo_512/checkpoint-best.pth" \
342
+ --lr 0.0001 --min_lr 1e-06 --warmup_epochs 1 --epochs 10 --batch_size 2 --accum_iter 8 \
343
+ --save_freq 1 --keep_freq 5 --eval_freq 1 --disable_cudnn_benchmark \
344
+ --output_dir "checkpoints/dust3r_demo_512dpt"
345
+
346
+ ```
347
+
348
+ ### Our Hyperparameters
349
+
350
+ Here are the commands we used for training the models:
351
+
352
+ ```bash
353
+ # NOTE: ROOT path omitted for datasets
354
+ # 224 linear
355
+ torchrun --nproc_per_node 8 train.py \
356
+ --train_dataset=" + 100_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ BlendedMVS(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ MegaDepth(split='train', aug_crop=16, resolution=224, transform=ColorJitter) + 100_000 @ ARKitScenes(aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=224, transform=ColorJitter) + 100_000 @ ScanNetpp(split='train', aug_crop=256, resolution=224, transform=ColorJitter) + 100_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=224, transform=ColorJitter) " \
357
+ --test_dataset=" Habitat(1_000, split='val', resolution=224, seed=777) + 1_000 @ BlendedMVS(split='val', resolution=224, seed=777) + 1_000 @ MegaDepth(split='val', resolution=224, seed=777) + 1_000 @ Co3d(split='test', mask_bg='rand', resolution=224, seed=777) " \
358
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
359
+ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
360
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', img_size=(224, 224), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
361
+ --pretrained="checkpoints/CroCo_V2_ViTLarge_BaseDecoder.pth" \
362
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=10 --epochs=100 --batch_size=16 --accum_iter=1 \
363
+ --save_freq=5 --keep_freq=10 --eval_freq=1 \
364
+ --output_dir="checkpoints/dust3r_224"
365
+
366
+ # 512 linear
367
+ torchrun --nproc_per_node 8 train.py \
368
+ --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
369
+ --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
370
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
371
+ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
372
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='linear', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
373
+ --pretrained="checkpoints/dust3r_224/checkpoint-best.pth" \
374
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=20 --epochs=100 --batch_size=4 --accum_iter=2 \
375
+ --save_freq=10 --keep_freq=10 --eval_freq=1 --print_freq=10 \
376
+ --output_dir="checkpoints/dust3r_512"
377
+
378
+ # 512 dpt
379
+ torchrun --nproc_per_node 8 train.py \
380
+ --train_dataset=" + 10_000 @ Habitat(1_000_000, split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ BlendedMVS(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ MegaDepth(split='train', aug_crop=16, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ARKitScenes(aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ Co3d(split='train', aug_crop=16, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ StaticThings3D(aug_crop=256, mask_bg='rand', resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ ScanNetpp(split='train', aug_crop=256, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) + 10_000 @ InternalUnreleasedDataset(aug_crop=128, resolution=[(512, 384), (512, 336), (512, 288), (512, 256), (512, 160)], transform=ColorJitter) " \
381
+ --test_dataset=" Habitat(1_000, split='val', resolution=(512,384), seed=777) + 1_000 @ BlendedMVS(split='val', resolution=(512,384), seed=777) + 1_000 @ MegaDepth(split='val', resolution=(512,336), seed=777) + 1_000 @ Co3d(split='test', resolution=(512,384), seed=777) " \
382
+ --train_criterion="ConfLoss(Regr3D(L21, norm_mode='avg_dis'), alpha=0.2)" \
383
+ --test_criterion="Regr3D_ScaleShiftInv(L21, gt_scale=True)" \
384
+ --model="AsymmetricCroCo3DStereo(pos_embed='RoPE100', patch_embed_cls='ManyAR_PatchEmbed', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_depth=12, dec_num_heads=12)" \
385
+ --pretrained="checkpoints/dust3r_512/checkpoint-best.pth" \
386
+ --lr=0.0001 --min_lr=1e-06 --warmup_epochs=15 --epochs=90 --batch_size=4 --accum_iter=2 \
387
+ --save_freq=5 --keep_freq=10 --eval_freq=1 --print_freq=10 --disable_cudnn_benchmark \
388
+ --output_dir="checkpoints/dust3r_512dpt"
389
+
390
+ ```
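+ 
+ For reference, the effective batch size of each of these commands is `nproc_per_node × batch_size × accum_iter`: 8 × 16 × 1 = 128 for the 224 linear model, and 8 × 4 × 2 = 64 for both 512 models.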
third_party/dust3r/croco/LICENSE ADDED
@@ -0,0 +1,52 @@
1
+ CroCo, Copyright (c) 2022-present Naver Corporation, is licensed under the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 license.
2
+
3
+ A summary of the CC BY-NC-SA 4.0 license is located here:
4
+ https://creativecommons.org/licenses/by-nc-sa/4.0/
5
+
6
+ The CC BY-NC-SA 4.0 license is located here:
7
+ https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
8
+
9
+
10
+ SEE NOTICE BELOW WITH RESPECT TO THE FILE: models/pos_embed.py, models/blocks.py
11
+
12
+ ***************************
13
+
14
+ NOTICE WITH RESPECT TO THE FILE: models/pos_embed.py
15
+
16
+ This software is being redistributed in a modified form. The original form is available here:
17
+
18
+ https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+
20
+ This software in this file incorporates parts of the following software available here:
21
+
22
+ Transformer: https://github.com/tensorflow/models/blob/master/official/legacy/transformer/model_utils.py
23
+ available under the following license: https://github.com/tensorflow/models/blob/master/LICENSE
24
+
25
+ MoCo v3: https://github.com/facebookresearch/moco-v3
26
+ available under the following license: https://github.com/facebookresearch/moco-v3/blob/main/LICENSE
27
+
28
+ DeiT: https://github.com/facebookresearch/deit
29
+ available under the following license: https://github.com/facebookresearch/deit/blob/main/LICENSE
30
+
31
+
32
+ THE ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE ARE REPRODUCED BELOW:
33
+
34
+ https://github.com/facebookresearch/mae/blob/main/LICENSE
35
+
36
+ Attribution-NonCommercial 4.0 International
37
+
38
+ ***************************
39
+
40
+ NOTICE WITH RESPECT TO THE FILE: models/blocks.py
41
+
42
+ This software is being redistributed in a modified form. The original form is available here:
43
+
44
+ https://github.com/rwightman/pytorch-image-models
45
+
46
+ THE ORIGINAL COPYRIGHT NOTICE AND PERMISSION NOTICE AVAILABLE HERE ARE REPRODUCED BELOW:
47
+
48
+ https://github.com/rwightman/pytorch-image-models/blob/master/LICENSE
49
+
50
+ Apache License
51
+ Version 2.0, January 2004
52
+ http://www.apache.org/licenses/
third_party/dust3r/croco/NOTICE ADDED
@@ -0,0 +1,21 @@
1
+ CroCo
2
+ Copyright 2022-present NAVER Corp.
3
+
4
+ This project contains subcomponents with separate copyright notices and license terms.
5
+ Your use of the source code for these subcomponents is subject to the terms and conditions of the following licenses.
6
+
7
+ ====
8
+
9
+ facebookresearch/mae
10
+ https://github.com/facebookresearch/mae
11
+
12
+ Attribution-NonCommercial 4.0 International
13
+
14
+ ====
15
+
16
+ rwightman/pytorch-image-models
17
+ https://github.com/rwightman/pytorch-image-models
18
+
19
+ Apache License
20
+ Version 2.0, January 2004
21
+ http://www.apache.org/licenses/
third_party/dust3r/croco/README.MD ADDED
@@ -0,0 +1,124 @@
1
+ # CroCo + CroCo v2 / CroCo-Stereo / CroCo-Flow
2
+
3
+ [[`CroCo arXiv`](https://arxiv.org/abs/2210.10716)] [[`CroCo v2 arXiv`](https://arxiv.org/abs/2211.10408)] [[`project page and demo`](https://croco.europe.naverlabs.com/)]
4
+
5
+ This repository contains the code for our CroCo model presented in our NeurIPS'22 paper [CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion](https://openreview.net/pdf?id=wZEfHUM5ri) and its follow-up extension published at ICCV'23 [Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow](https://openaccess.thecvf.com/content/ICCV2023/html/Weinzaepfel_CroCo_v2_Improved_Cross-view_Completion_Pre-training_for_Stereo_Matching_and_ICCV_2023_paper.html), referred to as CroCo v2:
6
+
7
+ ![image](assets/arch.jpg)
8
+
9
+ ```bibtex
10
+ @inproceedings{croco,
11
+ title={{CroCo: Self-Supervised Pre-training for 3D Vision Tasks by Cross-View Completion}},
12
+ author={{Weinzaepfel, Philippe and Leroy, Vincent and Lucas, Thomas and Br\'egier, Romain and Cabon, Yohann and Arora, Vaibhav and Antsfeld, Leonid and Chidlovskii, Boris and Csurka, Gabriela and Revaud J\'er\^ome}},
13
+ booktitle={{NeurIPS}},
14
+ year={2022}
15
+ }
16
+
17
+ @inproceedings{croco_v2,
18
+ title={{CroCo v2: Improved Cross-view Completion Pre-training for Stereo Matching and Optical Flow}},
19
+ author={Weinzaepfel, Philippe and Lucas, Thomas and Leroy, Vincent and Cabon, Yohann and Arora, Vaibhav and Br{\'e}gier, Romain and Csurka, Gabriela and Antsfeld, Leonid and Chidlovskii, Boris and Revaud, J{\'e}r{\^o}me},
20
+ booktitle={ICCV},
21
+ year={2023}
22
+ }
23
+ ```
24
+
25
+ ## License
26
+
27
+ The code is distributed under the CC BY-NC-SA 4.0 License. See [LICENSE](LICENSE) for more information.
28
+ Some components are based on code from [MAE](https://github.com/facebookresearch/mae) released under the CC BY-NC-SA 4.0 License and [timm](https://github.com/rwightman/pytorch-image-models) released under the Apache 2.0 License.
29
+ Some components for stereo matching and optical flow are based on code from [unimatch](https://github.com/autonomousvision/unimatch) released under the MIT license.
30
+
31
+ ## Preparation
32
+
33
+ 1. Install dependencies on a machine with an NVIDIA GPU, using e.g. conda. Note that `habitat-sim` is required only for the interactive demo and the synthetic pre-training data generation. If you don't plan to use it, you can skip the line installing it and use a more recent Python version.
34
+
35
+ ```bash
36
+ conda create -n croco python=3.7 cmake=3.14.0
37
+ conda activate croco
38
+ conda install habitat-sim headless -c conda-forge -c aihabitat
39
+ conda install pytorch torchvision -c pytorch
40
+ conda install notebook ipykernel matplotlib
41
+ conda install ipywidgets widgetsnbextension
42
+ conda install scikit-learn tqdm quaternion opencv # only for pretraining / habitat data generation
43
+
44
+ ```
45
+
46
+ 2. Compile cuda kernels for RoPE
47
+
48
+ CroCo v2 relies on RoPE positional embeddings for which you need to compile some cuda kernels.
49
+ ```bash
50
+ cd models/curope/
51
+ python setup.py build_ext --inplace
52
+ cd ../../
53
+ ```
54
+
55
+ This can take a while, as we compile for all CUDA architectures; feel free to update L9 of `models/curope/setup.py` to compile only for specific architectures.
56
+ You might also need to set the environment variable `CUDA_HOME` if you use a custom CUDA installation.
57
+
58
+ In case you cannot compile the CUDA kernels, a slower pure-PyTorch fallback is also provided and will be loaded automatically.
59
+
60
+ 3. Download pre-trained model
61
+
62
+ We provide several pre-trained models:
63
+
64
+ | modelname | pre-training data | pos. embed. | Encoder | Decoder |
65
+ |------------------------------------------------------------------------------------------------------------------------------------|-------------------|-------------|---------|---------|
66
+ | [`CroCo.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth) | Habitat | cosine | ViT-B | Small |
67
+ | [`CroCo_V2_ViTBase_SmallDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_SmallDecoder.pth) | Habitat + real | RoPE | ViT-B | Small |
68
+ | [`CroCo_V2_ViTBase_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTBase_BaseDecoder.pth) | Habitat + real | RoPE | ViT-B | Base |
69
+ | [`CroCo_V2_ViTLarge_BaseDecoder.pth`](https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo_V2_ViTLarge_BaseDecoder.pth) | Habitat + real | RoPE | ViT-L | Base |
70
+
71
+ To download a specific model, e.g. the first one (`CroCo.pth`):
72
+ ```bash
73
+ mkdir -p pretrained_models/
74
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/CroCo.pth -P pretrained_models/
75
+ ```
76
+
77
+ ## Reconstruction example
78
+
79
+ After downloading the `CroCo_V2_ViTLarge_BaseDecoder` pretrained model (or updating the corresponding line in `demo.py`), simply run:
80
+ ```bash
81
+ python demo.py
82
+ ```
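+ 
+ For reference, the checkpoint-loading step performed by `demo.py` looks roughly like the minimal sketch below. It assumes the released checkpoints store their weights under a `'model'` key and their constructor arguments under a `'croco_kwargs'` key; adjust the path to the model you downloaded.
+ ```python
+ import torch
+ from models.croco import CroCoNet
+ 
+ # Any of the released checkpoints should load the same way (path is an example).
+ ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', map_location='cpu')
+ model = CroCoNet(**ckpt.get('croco_kwargs', {}))
+ model.load_state_dict(ckpt['model'], strict=True)
+ model.eval()
+ ```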
83
+
84
+ ## Interactive demonstration of cross-view completion reconstruction on the Habitat simulator
85
+
86
+ First download the test scene from Habitat:
87
+ ```bash
88
+ python -m habitat_sim.utils.datasets_download --uids habitat_test_scenes --data-path habitat-sim-data/
89
+ ```
90
+
91
+ Then, run the Notebook demo `interactive_demo.ipynb`.
92
+
93
+ In this demo, you should be able to sample a random reference viewpoint from a [Habitat](https://github.com/facebookresearch/habitat-sim) test scene. Use the sliders to change the viewpoint and select a masked target view to reconstruct using CroCo.
94
+ ![croco_interactive_demo](https://user-images.githubusercontent.com/1822210/200516576-7937bc6a-55f8-49ed-8618-3ddf89433ea4.jpg)
95
+
96
+ ## Pre-training
97
+
98
+ ### CroCo
99
+
100
+ To pre-train CroCo, please first generate the pre-training data from the Habitat simulator, following the instructions in [datasets/habitat_sim/README.MD](datasets/habitat_sim/README.MD) and then run the following command:
101
+ ```
102
+ torchrun --nproc_per_node=4 pretrain.py --output_dir ./output/pretraining/
103
+ ```
104
+
105
+ Our CroCo pre-training was launched on a single server with 4 GPUs.
106
+ It should take around 10 days on A100 GPUs or 15 days on V100 GPUs to complete the 400 pre-training epochs, but decent performance is obtained earlier in training.
107
+ Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to verify that it remains valid in our case.
108
+ The first run can take a few minutes to start, as it parses all available pre-training pairs.
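+ (For reference, MAE's rule scales the learning rate linearly with the effective batch size: lr = base_lr × effective_batch_size / 256.)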
109
+
110
+ ### CroCo v2
111
+
112
+ For CroCo v2 pre-training, in addition to the generation of the pre-training data from the Habitat simulator above, please pre-extract the crops from the real datasets following the instructions in [datasets/crops/README.MD](datasets/crops/README.MD).
113
+ Then, run the following command for the largest model (ViT-L encoder, Base decoder):
114
+ ```
115
+ torchrun --nproc_per_node=8 pretrain.py --model "CroCoNet(enc_embed_dim=1024, enc_depth=24, enc_num_heads=16, dec_embed_dim=768, dec_num_heads=12, dec_depth=12, pos_embed='RoPE100')" --dataset "habitat_release+ARKitScenes+MegaDepth+3DStreetView+IndoorVL" --warmup_epochs 12 --max_epoch 125 --epochs 250 --amp 0 --keep_freq 5 --output_dir ./output/pretraining_crocov2/
116
+ ```
117
+
118
+ Our CroCo v2 pre-training was launched on a single server with 8 GPUs for the largest model, and on a single server with 4 GPUs for the smaller ones, keeping a batch size of 64 per GPU in all cases.
119
+ The largest model should take around 12 days on A100.
120
+ Note that, while the code contains the same learning-rate scaling rule as MAE when changing the effective batch size, we did not experiment to verify that it remains valid in our case.
121
+
122
+ ## Stereo matching and Optical flow downstream tasks
123
+
124
+ For CroCo-Stereo and CroCo-Flow, please refer to [stereoflow/README.MD](stereoflow/README.MD).
third_party/dust3r/croco/datasets/__init__.py ADDED
File without changes
third_party/dust3r/croco/datasets/crops/README.MD ADDED
@@ -0,0 +1,104 @@
1
+ ## Generation of crops from the real datasets
2
+
3
+ The instructions below describe how to generate the crops used for pre-training CroCo v2 from the following real-world datasets: ARKitScenes, MegaDepth, 3DStreetView and IndoorVL.
4
+
5
+ ### Download the metadata of the crops to generate
6
+
7
+ First, download the metadata and put them in `./data/`:
8
+ ```
9
+ mkdir -p data
10
+ cd data/
11
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/crop_metadata.zip
12
+ unzip crop_metadata.zip
13
+ rm crop_metadata.zip
14
+ cd ..
15
+ ```
16
+
17
+ ### Prepare the original datasets
18
+
19
+ Second, download the original datasets in `./data/original_datasets/`.
20
+ ```
21
+ mkdir -p data/original_datasets
22
+ ```
23
+
24
+ ##### ARKitScenes
25
+
26
+ Download the `raw` dataset from https://github.com/apple/ARKitScenes/blob/main/DATA.md and put it in `./data/original_datasets/ARKitScenes/`.
27
+ The resulting file structure should be like:
28
+ ```
29
+ ./data/original_datasets/ARKitScenes/
30
+ └───Training
31
+ └───40753679
32
+ │ │ ultrawide
33
+ │ │ ...
34
+ └───40753686
35
+
36
+ ...
37
+ ```
38
+
39
+ ##### MegaDepth
40
+
41
+ Download `MegaDepth v1 Dataset` from https://www.cs.cornell.edu/projects/megadepth/ and put it in `./data/original_datasets/MegaDepth/`.
42
+ The resulting file structure should be like:
43
+
44
+ ```
45
+ ./data/original_datasets/MegaDepth/
46
+ └───0000
47
+ │ └───images
48
+ │ │ │ 1000557903_87fa96b8a4_o.jpg
49
+ │ │ └ ...
50
+ │ └─── ...
51
+ └───0001
52
+ │ │
53
+ │ └ ...
54
+ └─── ...
55
+ ```
56
+
57
+ ##### 3DStreetView
58
+
59
+ Download `3D_Street_View` dataset from https://github.com/amir32002/3D_Street_View and put it in `./data/original_datasets/3DStreetView/`.
60
+ The resulting file structure should be like:
61
+
62
+ ```
63
+ ./data/original_datasets/3DStreetView/
64
+ └───dataset_aligned
65
+ │ └───0002
66
+ │ │ │ 0000002_0000001_0000002_0000001.jpg
67
+ │ │ └ ...
68
+ │ └─── ...
69
+ └───dataset_unaligned
70
+ │ └───0003
71
+ │ │ │ 0000003_0000001_0000002_0000001.jpg
72
+ │ │ └ ...
73
+ │ └─── ...
74
+ ```
75
+
76
+ ##### IndoorVL
77
+
78
+ Download the `IndoorVL` datasets using [Kapture](https://github.com/naver/kapture).
79
+
80
+ ```
81
+ pip install kapture
82
+ mkdir -p ./data/original_datasets/IndoorVL
83
+ cd ./data/original_datasets/IndoorVL
84
+ kapture_download_dataset.py update
85
+ kapture_download_dataset.py install "HyundaiDepartmentStore_*"
86
+ kapture_download_dataset.py install "GangnamStation_*"
87
+ cd -
88
+ ```
89
+
90
+ ### Extract the crops
91
+
92
+ Now, extract the crops for each of the datasets:
93
+ ```
94
+ for dataset in ARKitScenes MegaDepth 3DStreetView IndoorVL;
95
+ do
96
+ python3 datasets/crops/extract_crops_from_images.py --crops ./data/crop_metadata/${dataset}/crops_release.txt --root-dir ./data/original_datasets/${dataset}/ --output-dir ./data/${dataset}_crops/ --imsize 256 --nthread 8 --max-subdir-levels 5 --ideal-number-pairs-in-dir 500;
97
+ done
98
+ ```
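+ 
+ Each generated crop pair is saved as `<path>_1.jpg` / `<path>_2.jpg` under nested hexadecimal subdirectories of the output directory, and the list of all generated pair paths is written to `<output-dir>/listing.txt` (see `datasets/crops/extract_crops_from_images.py`).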
99
+
100
+ ##### Note for IndoorVL
101
+
102
+ Due to some legal issues, we can only release 144,228 pairs out of the 1,593,689 pairs used in the paper.
103
+ To compensate in terms of the number of pre-training iterations, the pre-training command in this repository uses 125 training epochs (including 12 warm-up epochs) and a cosine learning-rate schedule over 250 epochs, instead of 100, 10 and 200 respectively.
104
+ The impact on the performance is negligible.
third_party/dust3r/croco/datasets/crops/extract_crops_from_images.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Extracting crops for pre-training
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import argparse
10
+ from tqdm import tqdm
11
+ from PIL import Image
12
+ import functools
13
+ from multiprocessing import Pool
14
+ import math
15
+
16
+
17
+ def arg_parser():
18
+ parser = argparse.ArgumentParser('Generate cropped image pairs from image crop list')
19
+
20
+ parser.add_argument('--crops', type=str, required=True, help='crop file')
21
+ parser.add_argument('--root-dir', type=str, required=True, help='root directory')
22
+ parser.add_argument('--output-dir', type=str, required=True, help='output directory')
23
+ parser.add_argument('--imsize', type=int, default=256, help='size of the crops')
24
+ parser.add_argument('--nthread', type=int, required=True, help='number of simultaneous threads')
25
+ parser.add_argument('--max-subdir-levels', type=int, default=5, help='maximum number of subdirectories')
26
+ parser.add_argument('--ideal-number-pairs-in-dir', type=int, default=500, help='number of pairs stored in a dir')
27
+ return parser
28
+
29
+
30
+ def main(args):
31
+ listing_path = os.path.join(args.output_dir, 'listing.txt')
32
+
33
+ print(f'Loading list of crops ... ({args.nthread} threads)')
34
+ crops, num_crops_to_generate = load_crop_file(args.crops)
35
+
36
+ print(f'Preparing jobs ({len(crops)} candidate image pairs)...')
37
+ num_levels = min(math.ceil(math.log(num_crops_to_generate, args.ideal_number_pairs_in_dir)), args.max_subdir_levels)
38
+ num_pairs_in_dir = math.ceil(num_crops_to_generate ** (1/num_levels))
39
+
40
+ jobs = prepare_jobs(crops, num_levels, num_pairs_in_dir)
41
+ del crops
42
+
43
+ os.makedirs(args.output_dir, exist_ok=True)
44
+ mmap = Pool(args.nthread).imap_unordered if args.nthread > 1 else map
45
+ call = functools.partial(save_image_crops, args)
46
+
47
+ print(f"Generating cropped images to {args.output_dir} ...")
48
+ with open(listing_path, 'w') as listing:
49
+ listing.write('# pair_path\n')
50
+ for results in tqdm(mmap(call, jobs), total=len(jobs)):
51
+ for path in results:
52
+ listing.write(f'{path}\n')
53
+ print('Finished writing listing to', listing_path)
54
+
55
+
56
+ def load_crop_file(path):
57
+ data = open(path).read().splitlines()
58
+ pairs = []
59
+ num_crops_to_generate = 0
60
+ for line in tqdm(data):
61
+ if line.startswith('#'):
62
+ continue
63
+ line = line.split(', ')
64
+ if len(line) < 8:
65
+ img1, img2, rotation = line
66
+ pairs.append((img1, img2, int(rotation), []))
67
+ else:
68
+ l1, r1, t1, b1, l2, r2, t2, b2 = map(int, line)
69
+ rect1, rect2 = (l1, t1, r1, b1), (l2, t2, r2, b2)
70
+ pairs[-1][-1].append((rect1, rect2))
71
+ num_crops_to_generate += 1
72
+ return pairs, num_crops_to_generate
73
+
74
+
75
+ def prepare_jobs(pairs, num_levels, num_pairs_in_dir):
76
+ jobs = []
77
+ powers = [num_pairs_in_dir**level for level in reversed(range(num_levels))]
78
+
79
+ def get_path(idx):
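+ # Map a pair index to a nested hexadecimal path: the leading components spread pairs across subdirectories, and the last component is the full pair index in hex.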
80
+ idx_array = []
81
+ d = idx
82
+ for level in range(num_levels - 1):
83
+ idx_array.append(idx // powers[level])
84
+ idx = idx % powers[level]
85
+ idx_array.append(d)
86
+ return '/'.join(map(lambda x: hex(x)[2:], idx_array))
87
+
88
+ idx = 0
89
+ for pair_data in tqdm(pairs):
90
+ img1, img2, rotation, crops = pair_data
91
+ if -60 <= rotation and rotation <= 60:
92
+ rotation = 0 # most likely not a true rotation
93
+ paths = [get_path(idx + k) for k in range(len(crops))]
94
+ idx += len(crops)
95
+ jobs.append(((img1, img2), rotation, crops, paths))
96
+ return jobs
97
+
98
+
99
+ def load_image(path):
100
+ try:
101
+ return Image.open(path).convert('RGB')
102
+ except Exception as e:
103
+ print('skipping', path, e)
104
+ raise OSError()
105
+
106
+
107
+ def save_image_crops(args, data):
108
+ # load images
109
+ img_pair, rot, crops, paths = data
110
+ try:
111
+ img1, img2 = [load_image(os.path.join(args.root_dir, impath)) for impath in img_pair]
112
+ except OSError as e:
113
+ return []
114
+
115
+ def area(sz):
116
+ return sz[0] * sz[1]
117
+
118
+ tgt_size = (args.imsize, args.imsize)
119
+
120
+ def prepare_crop(img, rect, rot=0):
121
+ # actual crop
122
+ img = img.crop(rect)
123
+
124
+ # resize to desired size
125
+ interp = Image.Resampling.LANCZOS if area(img.size) > 4*area(tgt_size) else Image.Resampling.BICUBIC
126
+ img = img.resize(tgt_size, resample=interp)
127
+
128
+ # rotate the image
129
+ rot90 = (round(rot/90) % 4) * 90
130
+ if rot90 == 90:
131
+ img = img.transpose(Image.Transpose.ROTATE_90)
132
+ elif rot90 == 180:
133
+ img = img.transpose(Image.Transpose.ROTATE_180)
134
+ elif rot90 == 270:
135
+ img = img.transpose(Image.Transpose.ROTATE_270)
136
+ return img
137
+
138
+ results = []
139
+ for (rect1, rect2), path in zip(crops, paths):
140
+ crop1 = prepare_crop(img1, rect1)
141
+ crop2 = prepare_crop(img2, rect2, rot)
142
+
143
+ fullpath1 = os.path.join(args.output_dir, path+'_1.jpg')
144
+ fullpath2 = os.path.join(args.output_dir, path+'_2.jpg')
145
+ os.makedirs(os.path.dirname(fullpath1), exist_ok=True)
146
+
147
+ assert not os.path.isfile(fullpath1), fullpath1
148
+ assert not os.path.isfile(fullpath2), fullpath2
149
+ crop1.save(fullpath1)
150
+ crop2.save(fullpath2)
151
+ results.append(path)
152
+
153
+ return results
154
+
155
+
156
+ if __name__ == '__main__':
157
+ args = arg_parser().parse_args()
158
+ main(args)
159
+
third_party/dust3r/croco/datasets/habitat_sim/README.MD ADDED
@@ -0,0 +1,76 @@
1
+ ## Generation of synthetic image pairs using Habitat-Sim
2
+
3
+ These instructions describe how to generate pre-training pairs from the Habitat simulator.
4
+ As we did not save the metadata of the pairs used in the original paper, the generated pairs are not strictly identical, but they use the same settings and are equivalent.
5
+
6
+ ### Download Habitat-Sim scenes
7
+ Download Habitat-Sim scenes:
8
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
9
+ - We used scenes from the HM3D, habitat-test-scenes, Replica, ReplicaCad and ScanNet datasets.
10
+ - Please put the scenes under `./data/habitat-sim-data/scene_datasets/` following the structure below, or update manually paths in `paths.py`.
11
+ ```
12
+ ./data/
13
+ └──habitat-sim-data/
14
+ └──scene_datasets/
15
+ ├──hm3d/
16
+ ├──gibson/
17
+ ├──habitat-test-scenes/
18
+ ├──replica_cad_baked_lighting/
19
+ ├──replica_cad/
20
+ ├──ReplicaDataset/
21
+ └──scannet/
22
+ ```
23
+
24
+ ### Image pairs generation
25
+ We provide metadata to generate reproducible image pairs for pre-training and validation.
26
+ Experiments described in the paper used similar data, whose generation was not reproducible at the time.
27
+
28
+ Specifications:
29
+ - 256x256 images with a 60-degree field of view (hfov); see the intrinsics sketch after this list.
30
+ - Up to 1000 image pairs per scene.
31
+ - Number of scenes considered / number of image pairs per dataset:
32
+ - ScanNet: 1097 scenes / 985,209 pairs
33
+ - HM3D:
34
+ - hm3d/train: 800 scenes / 800k pairs
35
+ - hm3d/val: 100 scenes / 100k pairs
36
+ - hm3d/minival: 10 scenes / 10k pairs
37
+ - habitat-test-scenes: 3 scenes / 3k pairs
38
+ - replica_cad_baked_lighting: 13 scenes / 13k pairs
39
+
40
+ - Scenes from hm3d/val and hm3d/minival were not used for pre-training but were kept for validation purposes.
41
+
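+ For reference, a minimal sketch of the pinhole intrinsics implied by these settings, using the same formula as `compute_camera_intrinsics()` in `datasets/habitat_sim/multiview_habitat_sim_generator.py`:
+ ```python
+ import numpy as np
+ 
+ # 256x256 images with a 60-degree field of view (hfov)
+ height, width, hfov = 256, 256, 60
+ f = (width / 2) / np.tan(hfov / 2 * np.pi / 180)  # focal length in pixels, ~221.7
+ cu, cv = width / 2, height / 2                    # principal point, (128, 128)
+ K = np.array([[f, 0, cu],
+               [0, f, cv],
+               [0, 0, 1.0]])
+ ```
+ 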
42
+ Download metadata and extract it:
43
+ ```bash
44
+ mkdir -p data/habitat_release_metadata/
45
+ cd data/habitat_release_metadata/
46
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/data/habitat_release_metadata/multiview_habitat_metadata.tar.gz
47
+ tar -xvf multiview_habitat_metadata.tar.gz
48
+ cd ../..
49
+ # Location of the metadata
50
+ METADATA_DIR="./data/habitat_release_metadata/multiview_habitat_metadata"
51
+ ```
52
+
53
+ Generate image pairs from metadata:
54
+ - The following command will print a list of commandlines to generate image pairs for each scene:
55
+ ```bash
56
+ # Target output directory
57
+ PAIRS_DATASET_DIR="./data/habitat_release/"
58
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR
59
+ ```
60
+ - One can launch several such commands in parallel, e.g. using GNU Parallel:
61
+ ```bash
62
+ python datasets/habitat_sim/generate_from_metadata_files.py --input_dir=$METADATA_DIR --output_dir=$PAIRS_DATASET_DIR | parallel -j 16
63
+ ```
64
+
65
+ ## Metadata generation
66
+
67
+ Image pairs were randomly sampled using the following commands, whose outputs contain randomness and are thus not exactly reproducible:
68
+ ```bash
69
+ # Print commandlines to generate image pairs from the different scenes available.
70
+ PAIRS_DATASET_DIR=MY_CUSTOM_PATH
71
+ python datasets/habitat_sim/generate_multiview_images.py --list_commands --output_dir=$PAIRS_DATASET_DIR
72
+
73
+ # Once a dataset is generated, pack metadata files for reproducibility.
74
+ METADATA_DIR=MY_CUSTOM_PATH
75
+ python datasets/habitat_sim/pack_metadata_files.py $PAIRS_DATASET_DIR $METADATA_DIR
76
+ ```
third_party/dust3r/croco/datasets/habitat_sim/__init__.py ADDED
File without changes
third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Script to generate image pairs for a given scene reproducing poses provided in a metadata file.
6
+ """
7
+ import os
8
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator
9
+ from datasets.habitat_sim.paths import SCENES_DATASET
10
+ import argparse
11
+ import quaternion
12
+ import PIL.Image
13
+ import cv2
14
+ import json
15
+ from tqdm import tqdm
16
+
17
+ def generate_multiview_images_from_metadata(metadata_filename,
18
+ output_dir,
19
+ overload_params = dict(),
20
+ scene_datasets_paths=None,
21
+ exist_ok=False):
22
+ """
23
+ Generate images from a metadata file for reproducibility purposes.
24
+ """
25
+ # Reorder paths by decreasing label length, to avoid collisions when testing whether a path starts with a given dataset label
26
+ if scene_datasets_paths is not None:
27
+ scene_datasets_paths = dict(sorted(scene_datasets_paths.items(), key= lambda x: len(x[0]), reverse=True))
28
+
29
+ with open(metadata_filename, 'r') as f:
30
+ input_metadata = json.load(f)
31
+ metadata = dict()
32
+ for key, value in input_metadata.items():
33
+ # Optionally replace some paths
34
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
35
+ if scene_datasets_paths is not None:
36
+ for dataset_label, dataset_path in scene_datasets_paths.items():
37
+ if value.startswith(dataset_label):
38
+ value = os.path.normpath(os.path.join(dataset_path, os.path.relpath(value, dataset_label)))
39
+ break
40
+ metadata[key] = value
41
+
42
+ # Overload some parameters
43
+ for key, value in overload_params.items():
44
+ metadata[key] = value
45
+
46
+ generation_entries = dict([(key, value) for key, value in metadata.items() if not (key in ('multiviews', 'output_dir', 'generate_depth'))])
47
+ generate_depth = metadata["generate_depth"]
48
+
49
+ os.makedirs(output_dir, exist_ok=exist_ok)
50
+
51
+ generator = MultiviewHabitatSimGenerator(**generation_entries)
52
+
53
+ # Generate views
54
+ for idx_label, data in tqdm(metadata['multiviews'].items()):
55
+ positions = data["positions"]
56
+ orientations = data["orientations"]
57
+ n = len(positions)
58
+ for oidx in range(n):
59
+ observation = generator.render_viewpoint(positions[oidx], quaternion.from_float_array(orientations[oidx]))
60
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
61
+ # Color image saved using PIL
62
+ img = PIL.Image.fromarray(observation['color'][:,:,:3])
63
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
64
+ img.save(filename)
65
+ if generate_depth:
66
+ # Depth image as EXR file
67
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
68
+ cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
69
+ # Camera parameters
70
+ camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
71
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
72
+ with open(filename, "w") as f:
73
+ json.dump(camera_params, f)
74
+ # Save metadata
75
+ with open(os.path.join(output_dir, "metadata.json"), "w") as f:
76
+ json.dump(metadata, f)
77
+
78
+ generator.close()
79
+
80
+ if __name__ == "__main__":
81
+ parser = argparse.ArgumentParser()
82
+ parser.add_argument("--metadata_filename", required=True)
83
+ parser.add_argument("--output_dir", required=True)
84
+ args = parser.parse_args()
85
+
86
+ generate_multiview_images_from_metadata(metadata_filename=args.metadata_filename,
87
+ output_dir=args.output_dir,
88
+ scene_datasets_paths=SCENES_DATASET,
89
+ overload_params=dict(),
90
+ exist_ok=True)
91
+
92
+
third_party/dust3r/croco/datasets/habitat_sim/generate_from_metadata_files.py ADDED
@@ -0,0 +1,27 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Script generating commandlines to generate image pairs from metadata files.
6
+ """
7
+ import os
8
+ import glob
9
+ from tqdm import tqdm
10
+ import argparse
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--input_dir", required=True)
15
+ parser.add_argument("--output_dir", required=True)
16
+ parser.add_argument("--prefix", default="", help="Commanline prefix, useful e.g. to setup environment.")
17
+ args = parser.parse_args()
18
+
19
+ input_metadata_filenames = glob.iglob(f"{args.input_dir}/**/metadata.json", recursive=True)
20
+
21
+ for metadata_filename in tqdm(input_metadata_filenames):
22
+ output_dir = os.path.join(args.output_dir, os.path.relpath(os.path.dirname(metadata_filename), args.input_dir))
23
+ # Do not process the scene if the metadata file already exists
24
+ if os.path.exists(os.path.join(output_dir, "metadata.json")):
25
+ continue
26
+ commandline = f"{args.prefix}python datasets/habitat_sim/generate_from_metadata.py --metadata_filename={metadata_filename} --output_dir={output_dir}"
27
+ print(commandline)
third_party/dust3r/croco/datasets/habitat_sim/generate_multiview_images.py ADDED
@@ -0,0 +1,177 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from tqdm import tqdm
6
+ import argparse
7
+ import PIL.Image
8
+ import numpy as np
9
+ import json
10
+ from datasets.habitat_sim.multiview_habitat_sim_generator import MultiviewHabitatSimGenerator, NoNaviguableSpaceError
11
+ from datasets.habitat_sim.paths import list_scenes_available
12
+ import cv2
13
+ import quaternion
14
+ import shutil
15
+
16
+ def generate_multiview_images_for_scene(scene_dataset_config_file,
17
+ scene,
18
+ navmesh,
19
+ output_dir,
20
+ views_count,
21
+ size,
22
+ exist_ok=False,
23
+ generate_depth=False,
24
+ **kwargs):
25
+ """
26
+ Generate tuples of overlapping views for a given scene.
27
+ generate_depth: generate depth images and camera parameters.
28
+ """
29
+ if os.path.exists(output_dir) and not exist_ok:
30
+ print(f"Scene {scene}: data already generated. Ignoring generation.")
31
+ return
32
+ try:
33
+ print(f"Scene {scene}: {size} multiview acquisitions to generate...")
34
+ os.makedirs(output_dir, exist_ok=exist_ok)
35
+
36
+ metadata_filename = os.path.join(output_dir, "metadata.json")
37
+
38
+ metadata_template = dict(scene_dataset_config_file=scene_dataset_config_file,
39
+ scene=scene,
40
+ navmesh=navmesh,
41
+ views_count=views_count,
42
+ size=size,
43
+ generate_depth=generate_depth,
44
+ **kwargs)
45
+ metadata_template["multiviews"] = dict()
46
+
47
+ if os.path.exists(metadata_filename):
48
+ print("Metadata file already exists:", metadata_filename)
49
+ print("Loading already generated metadata file...")
50
+ with open(metadata_filename, "r") as f:
51
+ metadata = json.load(f)
52
+
53
+ for key in metadata_template.keys():
54
+ if key != "multiviews":
55
+ assert metadata_template[key] == metadata[key], f"existing file is inconsistent with the input parameters:\nKey: {key}\nmetadata: {metadata[key]}\ntemplate: {metadata_template[key]}."
56
+ else:
57
+ print("No temporary file found. Starting generation from scratch...")
58
+ metadata = metadata_template
59
+
60
+ starting_id = len(metadata["multiviews"])
61
+ print(f"Starting generation from index {starting_id}/{size}...")
62
+ if starting_id >= size:
63
+ print("Generation already done.")
64
+ return
65
+
66
+ generator = MultiviewHabitatSimGenerator(scene_dataset_config_file=scene_dataset_config_file,
67
+ scene=scene,
68
+ navmesh=navmesh,
69
+ views_count = views_count,
70
+ size = size,
71
+ **kwargs)
72
+
73
+ for idx in tqdm(range(starting_id, size)):
74
+ # Generate / re-generate the observations
75
+ try:
76
+ data = generator[idx]
77
+ observations = data["observations"]
78
+ positions = data["positions"]
79
+ orientations = data["orientations"]
80
+
81
+ idx_label = f"{idx:08}"
82
+ for oidx, observation in enumerate(observations):
83
+ observation_label = f"{oidx + 1}" # Leonid is indexing starting from 1
84
+ # Color image saved using PIL
85
+ img = PIL.Image.fromarray(observation['color'][:,:,:3])
86
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}.jpeg")
87
+ img.save(filename)
88
+ if generate_depth:
89
+ # Depth image as EXR file
90
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_depth.exr")
91
+ cv2.imwrite(filename, observation['depth'], [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])
92
+ # Camera parameters
93
+ camera_params = dict([(key, observation[key].tolist()) for key in ("camera_intrinsics", "R_cam2world", "t_cam2world")])
94
+ filename = os.path.join(output_dir, f"{idx_label}_{observation_label}_camera_params.json")
95
+ with open(filename, "w") as f:
96
+ json.dump(camera_params, f)
97
+ metadata["multiviews"][idx_label] = {"positions": positions.tolist(),
98
+ "orientations": orientations.tolist(),
99
+ "covisibility_ratios": data["covisibility_ratios"].tolist(),
100
+ "valid_fractions": data["valid_fractions"].tolist(),
101
+ "pairwise_visibility_ratios": data["pairwise_visibility_ratios"].tolist()}
102
+ except RecursionError:
103
+ print("Recursion error: unable to sample observations for this scene. We will stop there.")
104
+ break
105
+
106
+ # Regularly save a temporary metadata file, in case we need to restart the generation
107
+ if idx % 10 == 0:
108
+ with open(metadata_filename, "w") as f:
109
+ json.dump(metadata, f)
110
+
111
+ # Save metadata
112
+ with open(metadata_filename, "w") as f:
113
+ json.dump(metadata, f)
114
+
115
+ generator.close()
116
+ except NoNaviguableSpaceError:
117
+ pass
118
+
119
+ def create_commandline(scene_data, generate_depth, exist_ok=False):
120
+ """
121
+ Create a commandline string to generate a scene.
122
+ """
123
+ def my_formatting(val):
124
+ if val is None or val == "":
125
+ return '""'
126
+ else:
127
+ return val
128
+ commandline = f"""python {__file__} --scene {my_formatting(scene_data.scene)}
129
+ --scene_dataset_config_file {my_formatting(scene_data.scene_dataset_config_file)}
130
+ --navmesh {my_formatting(scene_data.navmesh)}
131
+ --output_dir {my_formatting(scene_data.output_dir)}
132
+ --generate_depth {int(generate_depth)}
133
+ --exist_ok {int(exist_ok)}
134
+ """
135
+ commandline = " ".join(commandline.split())
136
+ return commandline
137
+
138
+ if __name__ == "__main__":
139
+ os.umask(2)
140
+
141
+ parser = argparse.ArgumentParser(description="""Example of use -- listing commands to generate data for scenes available:
142
+ > python datasets/habitat_sim/generate_multiview_images.py --list_commands
143
+ """)
144
+
145
+ parser.add_argument("--output_dir", type=str, required=True)
146
+ parser.add_argument("--list_commands", action='store_true', help="list commandlines to run if true")
147
+ parser.add_argument("--scene", type=str, default="")
148
+ parser.add_argument("--scene_dataset_config_file", type=str, default="")
149
+ parser.add_argument("--navmesh", type=str, default="")
150
+
151
+ parser.add_argument("--generate_depth", type=int, default=1)
152
+ parser.add_argument("--exist_ok", type=int, default=0)
153
+
154
+ kwargs = dict(resolution=(256,256), hfov=60, views_count = 2, size=1000)
155
+
156
+ args = parser.parse_args()
157
+ generate_depth=bool(args.generate_depth)
158
+ exist_ok = bool(args.exist_ok)
159
+
160
+ if args.list_commands:
161
+ # Listing scenes available...
162
+ scenes_data = list_scenes_available(base_output_dir=args.output_dir)
163
+
164
+ for scene_data in scenes_data:
165
+ print(create_commandline(scene_data, generate_depth=generate_depth, exist_ok=exist_ok))
166
+ else:
167
+ if args.scene == "" or args.output_dir == "":
168
+ print("Missing scene or output dir argument!")
169
+ print(parser.format_help())
170
+ else:
171
+ generate_multiview_images_for_scene(scene=args.scene,
172
+ scene_dataset_config_file = args.scene_dataset_config_file,
173
+ navmesh = args.navmesh,
174
+ output_dir = args.output_dir,
175
+ exist_ok=exist_ok,
176
+ generate_depth=generate_depth,
177
+ **kwargs)
third_party/dust3r/croco/datasets/habitat_sim/multiview_habitat_sim_generator.py ADDED
@@ -0,0 +1,390 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ import numpy as np
6
+ import quaternion
7
+ import habitat_sim
8
+ import json
9
+ from sklearn.neighbors import NearestNeighbors
10
+ import cv2
11
+
12
+ # OpenCV to habitat camera convention transformation
13
+ R_OPENCV2HABITAT = np.stack((habitat_sim.geo.RIGHT, -habitat_sim.geo.UP, habitat_sim.geo.FRONT), axis=0)
14
+ R_HABITAT2OPENCV = R_OPENCV2HABITAT.T
15
+ DEG2RAD = np.pi / 180
16
+
17
+ def compute_camera_intrinsics(height, width, hfov):
18
+ f = width/2 / np.tan(hfov/2 * np.pi/180)
19
+ cu, cv = width/2, height/2
20
+ return f, cu, cv
21
+
22
+ def compute_camera_pose_opencv_convention(camera_position, camera_orientation):
23
+ R_cam2world = quaternion.as_rotation_matrix(camera_orientation) @ R_OPENCV2HABITAT
24
+ t_cam2world = np.asarray(camera_position)
25
+ return R_cam2world, t_cam2world
26
+
27
+ def compute_pointmap(depthmap, hfov):
28
+ """ Compute a HxWx3 pointmap in camera frame from a HxW depth map."""
29
+ height, width = depthmap.shape
30
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
31
+ # Cast depth map to point
32
+ z_cam = depthmap
33
+ u, v = np.meshgrid(range(width), range(height))
34
+ x_cam = (u - cu) / f * z_cam
35
+ y_cam = (v - cv) / f * z_cam
36
+ X_cam = np.stack((x_cam, y_cam, z_cam), axis=-1)
37
+ return X_cam
38
+
39
+ def compute_pointcloud(depthmap, hfov, camera_position, camera_rotation):
40
+ """Return a 3D point cloud corresponding to valid pixels of the depth map"""
41
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_position, camera_rotation)
42
+
43
+ X_cam = compute_pointmap(depthmap=depthmap, hfov=hfov)
44
+ valid_mask = (X_cam[:,:,2] != 0.0)
45
+
46
+ X_cam = X_cam.reshape(-1, 3)[valid_mask.flatten()]
47
+ X_world = X_cam @ R_cam2world.T + t_cam2world.reshape(1, 3)
48
+ return X_world
49
+
50
+ def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold, compute_symmetric=False):
51
+ """
52
+ Compute 'overlapping' metrics based on a distance threshold between two point clouds.
53
+ """
54
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
55
+ distances, indices = nbrs.kneighbors(pointcloud1)
56
+ intersection1 = np.count_nonzero(distances.flatten() < distance_threshold)
57
+
58
+ data = {"intersection1": intersection1,
59
+ "size1": len(pointcloud1)}
60
+ if compute_symmetric:
61
+ nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud1)
62
+ distances, indices = nbrs.kneighbors(pointcloud2)
63
+ intersection2 = np.count_nonzero(distances.flatten() < distance_threshold)
64
+ data["intersection2"] = intersection2
65
+ data["size2"] = len(pointcloud2)
66
+
67
+ return data
68
+
69
+ def _append_camera_parameters(observation, hfov, camera_location, camera_rotation):
70
+ """
71
+ Add camera parameters to the observation dictionnary produced by Habitat-Sim
72
+ In-place modifications.
73
+ """
74
+ R_cam2world, t_cam2world = compute_camera_pose_opencv_convention(camera_location, camera_rotation)
75
+ height, width = observation['depth'].shape
76
+ f, cu, cv = compute_camera_intrinsics(height, width, hfov)
77
+ K = np.asarray([[f, 0, cu],
78
+ [0, f, cv],
79
+ [0, 0, 1.0]])
80
+ observation["camera_intrinsics"] = K
81
+ observation["t_cam2world"] = t_cam2world
82
+ observation["R_cam2world"] = R_cam2world
83
+
84
+ def look_at(eye, center, up, return_cam2world=True):
85
+ """
86
+ Return camera pose looking at a given center point.
87
+ Analogous of gluLookAt function, using OpenCV camera convention.
88
+ """
89
+ z = center - eye
90
+ z /= np.linalg.norm(z, axis=-1, keepdims=True)
91
+ y = -up
92
+ y = y - np.sum(y * z, axis=-1, keepdims=True) * z
93
+ y /= np.linalg.norm(y, axis=-1, keepdims=True)
94
+ x = np.cross(y, z, axis=-1)
95
+
96
+ if return_cam2world:
97
+ R = np.stack((x, y, z), axis=-1)
98
+ t = eye
99
+ else:
100
+ # World to camera transformation
101
+ # Transposed matrix
102
+ R = np.stack((x, y, z), axis=-2)
103
+ t = - np.einsum('...ij, ...j', R, eye)
104
+ return R, t
105
+
106
+ def look_at_for_habitat(eye, center, up, return_cam2world=True):
107
+ R, t = look_at(eye, center, up)
108
+ orientation = quaternion.from_rotation_matrix(R @ R_OPENCV2HABITAT.T)
109
+ return orientation, t
110
+
111
+ def generate_orientation_noise(pan_range, tilt_range, roll_range):
112
+ return (quaternion.from_rotation_vector(np.random.uniform(*pan_range) * DEG2RAD * habitat_sim.geo.UP)
113
+ * quaternion.from_rotation_vector(np.random.uniform(*tilt_range) * DEG2RAD * habitat_sim.geo.RIGHT)
114
+ * quaternion.from_rotation_vector(np.random.uniform(*roll_range) * DEG2RAD * habitat_sim.geo.FRONT))
115
+
116
+
117
+ class NoNaviguableSpaceError(RuntimeError):
118
+ def __init__(self, *args):
119
+ super().__init__(*args)
120
+
121
+ class MultiviewHabitatSimGenerator:
122
+ def __init__(self,
123
+ scene,
124
+ navmesh,
125
+ scene_dataset_config_file,
126
+ resolution = (240, 320),
127
+ views_count=2,
128
+ hfov = 60,
129
+ gpu_id = 0,
130
+ size = 10000,
131
+ minimum_covisibility = 0.5,
132
+ transform = None):
133
+ self.scene = scene
134
+ self.navmesh = navmesh
135
+ self.scene_dataset_config_file = scene_dataset_config_file
136
+ self.resolution = resolution
137
+ self.views_count = views_count
138
+ assert(self.views_count >= 1)
139
+ self.hfov = hfov
140
+ self.gpu_id = gpu_id
141
+ self.size = size
142
+ self.transform = transform
143
+
144
+ # Noise added to camera orientation
145
+ self.pan_range = (-3, 3)
146
+ self.tilt_range = (-10, 10)
147
+ self.roll_range = (-5, 5)
148
+
149
+ # Height range to sample cameras
150
+ self.height_range = (1.2, 1.8)
151
+
152
+ # Random steps between the camera views
153
+ self.random_steps_count = 5
154
+ self.random_step_variance = 2.0
155
+
156
+ # Minimum fraction of the scene which should be valid (well defined depth)
157
+ self.minimum_valid_fraction = 0.7
158
+
159
+ # Distance threshold used to select pairs
160
+ self.distance_threshold = 0.05
161
+ # Minimum IoU of a view point cloud with respect to the reference view to be kept.
162
+ self.minimum_covisibility = minimum_covisibility
163
+
164
+ # Maximum number of retries.
165
+ self.max_attempts_count = 100
166
+
167
+ self.seed = None
168
+ self._lazy_initialization()
169
+
170
+ def _lazy_initialization(self):
171
+ # Lazy random seeding and instantiation of the simulator to deal with multiprocessing properly
172
+ if self.seed is None:
173
+ # Re-seed numpy generator
174
+ np.random.seed()
175
+ self.seed = np.random.randint(2**32-1)
176
+ sim_cfg = habitat_sim.SimulatorConfiguration()
177
+ sim_cfg.scene_id = self.scene
178
+ if self.scene_dataset_config_file is not None and self.scene_dataset_config_file != "":
179
+ sim_cfg.scene_dataset_config_file = self.scene_dataset_config_file
180
+ sim_cfg.random_seed = self.seed
181
+ sim_cfg.load_semantic_mesh = False
182
+ sim_cfg.gpu_device_id = self.gpu_id
183
+
184
+ depth_sensor_spec = habitat_sim.CameraSensorSpec()
185
+ depth_sensor_spec.uuid = "depth"
186
+ depth_sensor_spec.sensor_type = habitat_sim.SensorType.DEPTH
187
+ depth_sensor_spec.resolution = self.resolution
188
+ depth_sensor_spec.hfov = self.hfov
189
+ depth_sensor_spec.position = [0.0, 0.0, 0]
190
+ depth_sensor_spec.orientation
191
+
192
+ rgb_sensor_spec = habitat_sim.CameraSensorSpec()
193
+ rgb_sensor_spec.uuid = "color"
194
+ rgb_sensor_spec.sensor_type = habitat_sim.SensorType.COLOR
195
+ rgb_sensor_spec.resolution = self.resolution
196
+ rgb_sensor_spec.hfov = self.hfov
197
+ rgb_sensor_spec.position = [0.0, 0.0, 0]
198
+ agent_cfg = habitat_sim.agent.AgentConfiguration(sensor_specifications=[rgb_sensor_spec, depth_sensor_spec])
199
+
200
+ cfg = habitat_sim.Configuration(sim_cfg, [agent_cfg])
201
+ self.sim = habitat_sim.Simulator(cfg)
202
+ if self.navmesh is not None and self.navmesh != "":
203
+ # Use pre-computed navmesh when available (usually better than those generated automatically)
204
+ self.sim.pathfinder.load_nav_mesh(self.navmesh)
205
+
206
+ if not self.sim.pathfinder.is_loaded:
207
+ # Try to compute a navmesh
208
+ navmesh_settings = habitat_sim.NavMeshSettings()
209
+ navmesh_settings.set_defaults()
210
+ self.sim.recompute_navmesh(self.sim.pathfinder, navmesh_settings, True)
211
+
212
+ # Ensure that the navmesh is not empty
213
+ if not self.sim.pathfinder.is_loaded:
214
+ raise NoNaviguableSpaceError(f"No naviguable location (scene: {self.scene} -- navmesh: {self.navmesh})")
215
+
216
+ self.agent = self.sim.initialize_agent(agent_id=0)
217
+
218
+ def close(self):
219
+ self.sim.close()
220
+
221
+ def __del__(self):
222
+ self.sim.close()
223
+
224
+ def __len__(self):
225
+ return self.size
226
+
227
+ def sample_random_viewpoint(self):
228
+ """ Sample a random viewpoint using the navmesh """
229
+ nav_point = self.sim.pathfinder.get_random_navigable_point()
230
+
231
+ # Sample a random viewpoint height
232
+ viewpoint_height = np.random.uniform(*self.height_range)
233
+ viewpoint_position = nav_point + viewpoint_height * habitat_sim.geo.UP
234
+ viewpoint_orientation = quaternion.from_rotation_vector(np.random.uniform(0, 2 * np.pi) * habitat_sim.geo.UP) * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
235
+ return viewpoint_position, viewpoint_orientation, nav_point
236
+
237
+ def sample_other_random_viewpoint(self, observed_point, nav_point):
238
+ """ Sample a random viewpoint close to an existing one, using the navmesh and a reference observed point."""
239
+ other_nav_point = nav_point
240
+
241
+ walk_directions = self.random_step_variance * np.asarray([1,0,1])
242
+ for i in range(self.random_steps_count):
243
+ temp = self.sim.pathfinder.snap_point(other_nav_point + walk_directions * np.random.normal(size=3))
244
+ # Snapping may return nan when it fails
245
+ if not np.isnan(temp[0]):
246
+ other_nav_point = temp
247
+
248
+ other_viewpoint_height = np.random.uniform(*self.height_range)
249
+ other_viewpoint_position = other_nav_point + other_viewpoint_height * habitat_sim.geo.UP
250
+
251
+ # Set viewing direction towards the central point
252
+ rotation, position = look_at_for_habitat(eye=other_viewpoint_position, center=observed_point, up=habitat_sim.geo.UP, return_cam2world=True)
253
+ rotation = rotation * generate_orientation_noise(self.pan_range, self.tilt_range, self.roll_range)
254
+ return position, rotation, other_nav_point
255
+
256
+ def is_other_pointcloud_overlapping(self, ref_pointcloud, other_pointcloud):
257
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
258
+ # Observation
259
+ pixels_count = self.resolution[0] * self.resolution[1]
260
+ valid_fraction = len(other_pointcloud) / pixels_count
261
+ assert valid_fraction <= 1.0 and valid_fraction >= 0.0
262
+ overlap = compute_pointcloud_overlaps_scikit(ref_pointcloud, other_pointcloud, self.distance_threshold, compute_symmetric=True)
263
+ covisibility = min(overlap["intersection1"] / pixels_count, overlap["intersection2"] / pixels_count)
264
+ is_valid = (valid_fraction >= self.minimum_valid_fraction) and (covisibility >= self.minimum_covisibility)
265
+ return is_valid, valid_fraction, covisibility
266
+
267
+ def is_other_viewpoint_overlapping(self, ref_pointcloud, observation, position, rotation):
268
+ """ Check if a viewpoint is valid and overlaps significantly with a reference one. """
269
+ # Observation
270
+ other_pointcloud = compute_pointcloud(observation['depth'], self.hfov, position, rotation)
271
+ return self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
272
+
273
+ def render_viewpoint(self, viewpoint_position, viewpoint_orientation):
274
+ agent_state = habitat_sim.AgentState()
275
+ agent_state.position = viewpoint_position
276
+ agent_state.rotation = viewpoint_orientation
277
+ self.agent.set_state(agent_state)
278
+ viewpoint_observations = self.sim.get_sensor_observations(agent_ids=0)
279
+ _append_camera_parameters(viewpoint_observations, self.hfov, viewpoint_position, viewpoint_orientation)
280
+ return viewpoint_observations
281
+
282
+ def __getitem__(self, useless_idx):
283
+ ref_position, ref_orientation, nav_point = self.sample_random_viewpoint()
284
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
285
+ # Extract point cloud
286
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
287
+ camera_position=ref_position, camera_rotation=ref_orientation)
288
+
289
+ pixels_count = self.resolution[0] * self.resolution[1]
290
+ ref_valid_fraction = len(ref_pointcloud) / pixels_count
291
+ assert ref_valid_fraction <= 1.0 and ref_valid_fraction >= 0.0
292
+ if ref_valid_fraction < self.minimum_valid_fraction:
293
+ # This should produce a recursion error at some point when something is very wrong.
294
+ return self[0]
295
+ # Pick an reference observed point in the point cloud
296
+ observed_point = np.mean(ref_pointcloud, axis=0)
297
+
298
+ # Add the first image as reference
299
+ viewpoints_observations = [ref_observations]
300
+ viewpoints_covisibility = [ref_valid_fraction]
301
+ viewpoints_positions = [ref_position]
302
+ viewpoints_orientations = [quaternion.as_float_array(ref_orientation)]
303
+ viewpoints_clouds = [ref_pointcloud]
304
+ viewpoints_valid_fractions = [ref_valid_fraction]
305
+
306
+ for _ in range(self.views_count - 1):
307
+ # Generate an other viewpoint using some dummy random walk
308
+ successful_sampling = False
309
+ for sampling_attempt in range(self.max_attempts_count):
310
+ position, rotation, _ = self.sample_other_random_viewpoint(observed_point, nav_point)
311
+ # Observation
312
+ other_viewpoint_observations = self.render_viewpoint(position, rotation)
313
+ other_pointcloud = compute_pointcloud(other_viewpoint_observations['depth'], self.hfov, position, rotation)
314
+
315
+ is_valid, valid_fraction, covisibility = self.is_other_pointcloud_overlapping(ref_pointcloud, other_pointcloud)
316
+ if is_valid:
317
+ successful_sampling = True
318
+ break
319
+ if not successful_sampling:
320
+ print("WARNING: Maximum number of attempts reached.")
321
+ # Dirty hack, try using a novel original viewpoint
322
+ return self[0]
323
+ viewpoints_observations.append(other_viewpoint_observations)
324
+ viewpoints_covisibility.append(covisibility)
325
+ viewpoints_positions.append(position)
326
+ viewpoints_orientations.append(quaternion.as_float_array(rotation)) # WXYZ convention for the quaternion encoding.
327
+ viewpoints_clouds.append(other_pointcloud)
328
+ viewpoints_valid_fractions.append(valid_fraction)
329
+
330
+ # Estimate relations between all pairs of images
331
+ pairwise_visibility_ratios = np.ones((len(viewpoints_observations), len(viewpoints_observations)))
332
+ for i in range(len(viewpoints_observations)):
333
+ pairwise_visibility_ratios[i,i] = viewpoints_valid_fractions[i]
334
+ for j in range(i+1, len(viewpoints_observations)):
335
+ overlap = compute_pointcloud_overlaps_scikit(viewpoints_clouds[i], viewpoints_clouds[j], self.distance_threshold, compute_symmetric=True)
336
+ pairwise_visibility_ratios[i,j] = overlap['intersection1'] / pixels_count
337
+ pairwise_visibility_ratios[j,i] = overlap['intersection2'] / pixels_count
338
+
339
+ # IoU is relative to the image 0
340
+ data = {"observations": viewpoints_observations,
341
+ "positions": np.asarray(viewpoints_positions),
342
+ "orientations": np.asarray(viewpoints_orientations),
343
+ "covisibility_ratios": np.asarray(viewpoints_covisibility),
344
+ "valid_fractions": np.asarray(viewpoints_valid_fractions, dtype=float),
345
+ "pairwise_visibility_ratios": np.asarray(pairwise_visibility_ratios, dtype=float),
346
+ }
347
+
348
+ if self.transform is not None:
349
+ data = self.transform(data)
350
+ return data
351
+
352
+ def generate_random_spiral_trajectory(self, images_count = 100, max_radius=0.5, half_turns=5, use_constant_orientation=False):
353
+ """
354
+ Return a list of images corresponding to a spiral trajectory from a random starting point.
355
+ Useful to generate nice visualisations.
356
+ Use an even number of half turns to get a nice "C1-continuous" loop effect
357
+ """
358
+ ref_position, ref_orientation, navpoint = self.sample_random_viewpoint()
359
+ ref_observations = self.render_viewpoint(ref_position, ref_orientation)
360
+ ref_pointcloud = compute_pointcloud(depthmap=ref_observations['depth'], hfov=self.hfov,
361
+ camera_position=ref_position, camera_rotation=ref_orientation)
362
+ pixels_count = self.resolution[0] * self.resolution[1]
363
+ if len(ref_pointcloud) / pixels_count < self.minimum_valid_fraction:
364
+ # Dirty hack: ensure that the valid part of the image is significant
365
+ return self.generate_random_spiral_trajectory(images_count, max_radius, half_turns, use_constant_orientation)
366
+
367
+ # Pick an observed point in the point cloud
368
+ observed_point = np.mean(ref_pointcloud, axis=0)
369
+ ref_R, ref_t = compute_camera_pose_opencv_convention(ref_position, ref_orientation)
370
+
371
+ images = []
372
+ is_valid = []
373
+ # Spiral trajectory, optionally with a constant orientation
374
+ for i, alpha in enumerate(np.linspace(0, 1, images_count)):
375
+ r = max_radius * np.abs(np.sin(alpha * np.pi)) # Increase then decrease the radius
376
+ theta = alpha * half_turns * np.pi
377
+ x = r * np.cos(theta)
378
+ y = r * np.sin(theta)
379
+ z = 0.0
380
+ position = ref_position + (ref_R @ np.asarray([x, y, z]).reshape(3,1)).flatten()
381
+ if use_constant_orientation:
382
+ orientation = ref_orientation
383
+ else:
384
+ # trajectory looking at a mean point in front of the ref observation
385
+ orientation, position = look_at_for_habitat(eye=position, center=observed_point, up=habitat_sim.geo.UP)
386
+ observations = self.render_viewpoint(position, orientation)
387
+ images.append(observations['color'][...,:3])
388
+ _is_valid, valid_fraction, iou = self.is_other_viewpoint_overlapping(ref_pointcloud, observations, position, orientation)
389
+ is_valid.append(_is_valid)
390
+ return images, np.all(is_valid)
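
The __getitem__ above returns a dictionary whose "pairwise_visibility_ratios" entry is an NxN matrix of directed overlap ratios between the rendered views. A minimal sketch of post-processing such a matrix (the helper name and the 0.3 threshold are illustrative, not part of the release):

import numpy as np

def select_covisible_pairs(pairwise_visibility_ratios, threshold=0.3):
    """Return the (i, j) view pairs whose overlap exceeds the threshold in both directions."""
    ratios = np.asarray(pairwise_visibility_ratios)
    pairs = []
    for i in range(ratios.shape[0]):
        for j in range(i + 1, ratios.shape[0]):
            if min(ratios[i, j], ratios[j, i]) >= threshold:
                pairs.append((i, j))
    return pairs

# toy 3-view example standing in for data["pairwise_visibility_ratios"]
example = np.array([[1.0, 0.5, 0.1],
                    [0.4, 1.0, 0.6],
                    [0.2, 0.7, 1.0]])
print(select_covisible_pairs(example))  # [(0, 1), (1, 2)]
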
third_party/dust3r/croco/datasets/habitat_sim/pack_metadata_files.py ADDED
@@ -0,0 +1,69 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ """
4
+ Utility script to pack metadata files of the dataset in order to be able to re-generate it elsewhere.
5
+ """
6
+ import os
7
+ import glob
8
+ from tqdm import tqdm
9
+ import shutil
10
+ import json
11
+ from datasets.habitat_sim.paths import *
12
+ import argparse
13
+ import collections
14
+
15
+ if __name__ == "__main__":
16
+ parser = argparse.ArgumentParser()
17
+ parser.add_argument("input_dir")
18
+ parser.add_argument("output_dir")
19
+ args = parser.parse_args()
20
+
21
+ input_dirname = args.input_dir
22
+ output_dirname = args.output_dir
23
+
24
+ input_metadata_filenames = glob.iglob(f"{input_dirname}/**/metadata.json", recursive=True)
25
+
26
+ images_count = collections.defaultdict(lambda : 0)
27
+
28
+ os.makedirs(output_dirname)
29
+ for input_filename in tqdm(input_metadata_filenames):
30
+ # Ignore empty files
31
+ with open(input_filename, "r") as f:
32
+ original_metadata = json.load(f)
33
+ if "multiviews" not in original_metadata or len(original_metadata["multiviews"]) == 0:
34
+ print("No views in", input_filename)
35
+ continue
36
+
37
+ relpath = os.path.relpath(input_filename, input_dirname)
38
+ print(relpath)
39
+
40
+ # Copy metadata, while replacing scene paths by generic keys depending on the dataset, for portability.
41
+ # Data paths are sorted by decreasing length to avoid potential bugs due to paths starting with the same string pattern.
42
+ scenes_dataset_paths = dict(sorted(SCENES_DATASET.items(), key=lambda x: len(x[1]), reverse=True))
43
+ metadata = dict()
44
+ for key, value in original_metadata.items():
45
+ if key in ("scene_dataset_config_file", "scene", "navmesh") and value != "":
46
+ known_path = False
47
+ for dataset, dataset_path in scenes_dataset_paths.items():
48
+ if value.startswith(dataset_path):
49
+ value = os.path.join(dataset, os.path.relpath(value, dataset_path))
50
+ known_path = True
51
+ break
52
+ if not known_path:
53
+ raise KeyError("Unknown path:" + value)
54
+ metadata[key] = value
55
+
56
+ # Compile some general statistics while packing data
57
+ scene_split = metadata["scene"].split("/")
58
+ upper_level = "/".join(scene_split[:2]) if scene_split[0] == "hm3d" else scene_split[0]
59
+ images_count[upper_level] += len(metadata["multiviews"])
60
+
61
+ output_filename = os.path.join(output_dirname, relpath)
62
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
63
+ with open(output_filename, "w") as f:
64
+ json.dump(metadata, f)
65
+
66
+ # Print statistics
67
+ print("Images count:")
68
+ for upper_level, count in images_count.items():
69
+ print(f"- {upper_level}: {count}")
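
Re-generating the dataset elsewhere requires mapping the generic dataset keys written above back to local paths. An illustrative inverse of the packing step (the mapping, the packed path and unpack_path below are placeholders, not part of the release):

import os

scenes_dataset = {"hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/"}  # analogous to SCENES_DATASET
packed_scene = "hm3d/minival/some_scene/some_scene.basis.glb"              # as stored by this script

def unpack_path(value, mapping):
    # invert the os.path.join(dataset, relpath) performed during packing
    dataset, relpath = value.split("/", 1)
    return os.path.join(mapping[dataset], relpath)

print(unpack_path(packed_scene, scenes_dataset))
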
third_party/dust3r/croco/datasets/habitat_sim/paths.py ADDED
@@ -0,0 +1,129 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ """
5
+ Paths to Habitat-Sim scenes
6
+ """
7
+
8
+ import os
9
+ import json
10
+ import collections
11
+ from tqdm import tqdm
12
+
13
+
14
+ # Hardcoded path to the different scene datasets
15
+ SCENES_DATASET = {
16
+ "hm3d": "./data/habitat-sim-data/scene_datasets/hm3d/",
17
+ "gibson": "./data/habitat-sim-data/scene_datasets/gibson/",
18
+ "habitat-test-scenes": "./data/habitat-sim/scene_datasets/habitat-test-scenes/",
19
+ "replica_cad_baked_lighting": "./data/habitat-sim/scene_datasets/replica_cad_baked_lighting/",
20
+ "replica_cad": "./data/habitat-sim/scene_datasets/replica_cad/",
21
+ "replica": "./data/habitat-sim/scene_datasets/ReplicaDataset/",
22
+ "scannet": "./data/habitat-sim/scene_datasets/scannet/"
23
+ }
24
+
25
+ SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])
26
+
27
+ def list_replicacad_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad"]):
28
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD.scene_dataset_config.json")
29
+ scenes = [f"apt_{i}" for i in range(6)] + ["empty_stage"]
30
+ navmeshes = [f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
31
+ scenes_data = []
32
+ for idx in range(len(scenes)):
33
+ output_dir = os.path.join(base_output_dir, "ReplicaCAD", scenes[idx])
34
+ # Add scene
35
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
36
+ scene = scenes[idx] + ".scene_instance.json",
37
+ navmesh = os.path.join(base_path, navmeshes[idx]),
38
+ output_dir = output_dir)
39
+ scenes_data.append(data)
40
+ return scenes_data
41
+
42
+ def list_replica_cad_baked_lighting_scenes(base_output_dir, base_path=SCENES_DATASET["replica_cad_baked_lighting"]):
43
+ scene_dataset_config_file = os.path.join(base_path, "replicaCAD_baked.scene_dataset_config.json")
44
+ scenes = sum([[f"Baked_sc{i}_staging_{j:02}" for i in range(5)] for j in range(21)], [])
45
+ navmeshes = ""#[f"navmeshes/apt_{i}_static_furniture.navmesh" for i in range(6)] + ["empty_stage.navmesh"]
46
+ scenes_data = []
47
+ for idx in range(len(scenes)):
48
+ output_dir = os.path.join(base_output_dir, "replica_cad_baked_lighting", scenes[idx])
49
+ data = SceneData(scene_dataset_config_file=scene_dataset_config_file,
50
+ scene = scenes[idx],
51
+ navmesh = "",
52
+ output_dir = output_dir)
53
+ scenes_data.append(data)
54
+ return scenes_data
55
+
56
+ def list_replica_scenes(base_output_dir, base_path):
57
+ scenes_data = []
58
+ for scene_id in os.listdir(base_path):
59
+ scene = os.path.join(base_path, scene_id, "mesh.ply")
60
+ navmesh = os.path.join(base_path, scene_id, "habitat/mesh_preseg_semantic.navmesh") # Not sure if I should use it
61
+ scene_dataset_config_file = ""
62
+ output_dir = os.path.join(base_output_dir, scene_id)
63
+ # Add scene only if it does not exist already, or if exist_ok
64
+ data = SceneData(scene_dataset_config_file = scene_dataset_config_file,
65
+ scene = scene,
66
+ navmesh = navmesh,
67
+ output_dir = output_dir)
68
+ scenes_data.append(data)
69
+ return scenes_data
70
+
71
+
72
+ def list_scenes(base_output_dir, base_path):
73
+ """
74
+ Generic method iterating through a base_path folder to find scenes.
75
+ """
76
+ scenes_data = []
77
+ for root, dirs, files in os.walk(base_path, followlinks=True):
78
+ folder_scenes_data = []
79
+ for file in files:
80
+ name, ext = os.path.splitext(file)
81
+ if ext == ".glb":
82
+ scene = os.path.join(root, name + ".glb")
83
+ navmesh = os.path.join(root, name + ".navmesh")
84
+ if not os.path.exists(navmesh):
85
+ navmesh = ""
86
+ relpath = os.path.relpath(root, base_path)
87
+ output_dir = os.path.abspath(os.path.join(base_output_dir, relpath, name))
88
+ data = SceneData(scene_dataset_config_file="",
89
+ scene = scene,
90
+ navmesh = navmesh,
91
+ output_dir = output_dir)
92
+ folder_scenes_data.append(data)
93
+
94
+ # Specific check for HM3D:
95
+ # When two meshes xxxx.basis.glb and xxxx.glb are present, use the 'basis' version.
96
+ basis_scenes = [data.scene[:-len(".basis.glb")] for data in folder_scenes_data if data.scene.endswith(".basis.glb")]
97
+ if len(basis_scenes) != 0:
98
+ folder_scenes_data = [data for data in folder_scenes_data if not (data.scene[:-len(".glb")] in basis_scenes)]
99
+
100
+ scenes_data.extend(folder_scenes_data)
101
+ return scenes_data
102
+
103
+ def list_scenes_available(base_output_dir, scenes_dataset_paths=SCENES_DATASET):
104
+ scenes_data = []
105
+
106
+ # HM3D
107
+ for split in ("minival", "train", "val", "examples"):
108
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, f"hm3d/{split}/"),
109
+ base_path=f"{scenes_dataset_paths['hm3d']}/{split}")
110
+
111
+ # Gibson
112
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "gibson"),
113
+ base_path=scenes_dataset_paths["gibson"])
114
+
115
+ # Habitat test scenes (just a few)
116
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "habitat-test-scenes"),
117
+ base_path=scenes_dataset_paths["habitat-test-scenes"])
118
+
119
+ # ReplicaCAD (baked lighting)
120
+ scenes_data += list_replica_cad_baked_lighting_scenes(base_output_dir=base_output_dir)
121
+
122
+ # ScanNet
123
+ scenes_data += list_scenes(base_output_dir=os.path.join(base_output_dir, "scannet"),
124
+ base_path=scenes_dataset_paths["scannet"])
125
+
126
+ # Replica
127
+ scenes_data += list_replica_scenes(base_output_dir=os.path.join(base_output_dir, "replica"),
128
+ base_path=scenes_dataset_paths["replica"])
129
+ return scenes_data
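
All the listing helpers above return SceneData records, so the set of scenes to render can be filtered before generation. A toy example with placeholder values (no real scene files involved):

import collections

SceneData = collections.namedtuple("SceneData", ["scene_dataset_config_file", "scene", "navmesh", "output_dir"])

scenes = [SceneData("", "sceneA.glb", "sceneA.navmesh", "./out/sceneA"),
          SceneData("", "sceneB.glb", "", "./out/sceneB")]
with_navmesh = [s for s in scenes if s.navmesh != ""]
print([s.scene for s in with_navmesh])  # ['sceneA.glb']
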
third_party/dust3r/croco/datasets/pairs_dataset.py ADDED
@@ -0,0 +1,109 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import os
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+
8
+ from datasets.transforms import get_pair_transforms
9
+
10
+ def load_image(impath):
11
+ return Image.open(impath)
12
+
13
+ def load_pairs_from_cache_file(fname, root=''):
14
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
15
+ with open(fname, 'r') as fid:
16
+ lines = fid.read().strip().splitlines()
17
+ pairs = [ (os.path.join(root,l.split()[0]), os.path.join(root,l.split()[1])) for l in lines]
18
+ return pairs
19
+
20
+ def load_pairs_from_list_file(fname, root=''):
21
+ assert os.path.isfile(fname), "cannot parse pairs from {:s}, file does not exist".format(fname)
22
+ with open(fname, 'r') as fid:
23
+ lines = fid.read().strip().splitlines()
24
+ pairs = [ (os.path.join(root,l+'_1.jpg'), os.path.join(root,l+'_2.jpg')) for l in lines if not l.startswith('#')]
25
+ return pairs
26
+
27
+
28
+ def write_cache_file(fname, pairs, root=''):
29
+ if len(root)>0:
30
+ if not root.endswith('/'): root+='/'
31
+ assert os.path.isdir(root)
32
+ s = ''
33
+ for im1, im2 in pairs:
34
+ if len(root)>0:
35
+ assert im1.startswith(root), im1
36
+ assert im2.startswith(root), im2
37
+ s += '{:s} {:s}\n'.format(im1[len(root):], im2[len(root):])
38
+ with open(fname, 'w') as fid:
39
+ fid.write(s[:-1])
40
+
41
+ def parse_and_cache_all_pairs(dname, data_dir='./data/'):
42
+ if dname=='habitat_release':
43
+ dirname = os.path.join(data_dir, 'habitat_release')
44
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
45
+ cache_file = os.path.join(dirname, 'pairs.txt')
46
+ assert not os.path.isfile(cache_file), "cache file already exists: "+cache_file
47
+
48
+ print('Parsing pairs for dataset: '+dname)
49
+ pairs = []
50
+ for root, dirs, files in os.walk(dirname):
51
+ if 'val' in root: continue
52
+ dirs.sort()
53
+ pairs += [ (os.path.join(root,f), os.path.join(root,f[:-len('_1.jpeg')]+'_2.jpeg')) for f in sorted(files) if f.endswith('_1.jpeg')]
54
+ print('Found {:,} pairs'.format(len(pairs)))
55
+ print('Writing cache to: '+cache_file)
56
+ write_cache_file(cache_file, pairs, root=dirname)
57
+
58
+ else:
59
+ raise NotImplementedError('Unknown dataset: '+dname)
60
+
61
+ def dnames_to_image_pairs(dnames, data_dir='./data/'):
62
+ """
63
+ dnames: list of datasets with image pairs, separated by +
64
+ """
65
+ all_pairs = []
66
+ for dname in dnames.split('+'):
67
+ if dname=='habitat_release':
68
+ dirname = os.path.join(data_dir, 'habitat_release')
69
+ assert os.path.isdir(dirname), "cannot find folder for habitat_release pairs: "+dirname
70
+ cache_file = os.path.join(dirname, 'pairs.txt')
71
+ assert os.path.isfile(cache_file), "cannot find cache file for habitat_release pairs, please first create the cache file, see instructions. "+cache_file
72
+ pairs = load_pairs_from_cache_file(cache_file, root=dirname)
73
+ elif dname in ['ARKitScenes', 'MegaDepth', '3DStreetView', 'IndoorVL']:
74
+ dirname = os.path.join(data_dir, dname+'_crops')
75
+ assert os.path.isdir(dirname), "cannot find folder for {:s} pairs: {:s}".format(dname, dirname)
76
+ list_file = os.path.join(dirname, 'listing.txt')
77
+ assert os.path.isfile(list_file), "cannot find list file for {:s} pairs, see instructions. {:s}".format(dname, list_file)
78
+ pairs = load_pairs_from_list_file(list_file, root=dirname)
79
+ print(' {:s}: {:,} pairs'.format(dname, len(pairs)))
80
+ all_pairs += pairs
81
+ if '+' in dnames: print(' Total: {:,} pairs'.format(len(all_pairs)))
82
+ return all_pairs
83
+
84
+
85
+ class PairsDataset(Dataset):
86
+
87
+ def __init__(self, dnames, trfs='', totensor=True, normalize=True, data_dir='./data/'):
88
+ super().__init__()
89
+ self.image_pairs = dnames_to_image_pairs(dnames, data_dir=data_dir)
90
+ self.transforms = get_pair_transforms(transform_str=trfs, totensor=totensor, normalize=normalize)
91
+
92
+ def __len__(self):
93
+ return len(self.image_pairs)
94
+
95
+ def __getitem__(self, index):
96
+ im1path, im2path = self.image_pairs[index]
97
+ im1 = load_image(im1path)
98
+ im2 = load_image(im2path)
99
+ if self.transforms is not None: im1, im2 = self.transforms(im1, im2)
100
+ return im1, im2
101
+
102
+
103
+ if __name__=="__main__":
104
+ import argparse
105
+ parser = argparse.ArgumentParser(prog="Computing and caching list of pairs for a given dataset")
106
+ parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
107
+ parser.add_argument('--dataset', default='habitat_release', type=str, help="name of the dataset")
108
+ args = parser.parse_args()
109
+ parse_and_cache_all_pairs(dname=args.dataset, data_dir=args.data_dir)
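
Typical use of PairsDataset in a training loop, assuming the habitat_release pair list has already been cached under ./data/ by parse_and_cache_all_pairs and the croco root is on PYTHONPATH (paths, transform string and batch size below are placeholders):

from torch.utils.data import DataLoader
from datasets.pairs_dataset import PairsDataset

dataset = PairsDataset('habitat_release', trfs='crop224+acolor', data_dir='./data/')
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)
for im1, im2 in loader:
    # both tensors are batch_size x 3 x 224 x 224 after crop / color jitter / normalization
    print(im1.shape, im2.shape)
    break
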
third_party/dust3r/croco/datasets/transforms.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ import torchvision.transforms
6
+ import torchvision.transforms.functional as F
7
+
8
+ # "Pair": apply a transform on a pair
9
+ # "Both": apply the exact same transform to both images
10
+
11
+ class ComposePair(torchvision.transforms.Compose):
12
+ def __call__(self, img1, img2):
13
+ for t in self.transforms:
14
+ img1, img2 = t(img1, img2)
15
+ return img1, img2
16
+
17
+ class NormalizeBoth(torchvision.transforms.Normalize):
18
+ def forward(self, img1, img2):
19
+ img1 = super().forward(img1)
20
+ img2 = super().forward(img2)
21
+ return img1, img2
22
+
23
+ class ToTensorBoth(torchvision.transforms.ToTensor):
24
+ def __call__(self, img1, img2):
25
+ img1 = super().__call__(img1)
26
+ img2 = super().__call__(img2)
27
+ return img1, img2
28
+
29
+ class RandomCropPair(torchvision.transforms.RandomCrop):
30
+ # the crop will be intentionally different for the two images with this class
31
+ def forward(self, img1, img2):
32
+ img1 = super().forward(img1)
33
+ img2 = super().forward(img2)
34
+ return img1, img2
35
+
36
+ class ColorJitterPair(torchvision.transforms.ColorJitter):
37
+ # can be symmetric (same for both images) or asymmetric (different jitter params for each image) depending on assymetric_prob
38
+ def __init__(self, assymetric_prob, **kwargs):
39
+ super().__init__(**kwargs)
40
+ self.assymetric_prob = assymetric_prob
41
+ def jitter_one(self, img, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor):
42
+ for fn_id in fn_idx:
43
+ if fn_id == 0 and brightness_factor is not None:
44
+ img = F.adjust_brightness(img, brightness_factor)
45
+ elif fn_id == 1 and contrast_factor is not None:
46
+ img = F.adjust_contrast(img, contrast_factor)
47
+ elif fn_id == 2 and saturation_factor is not None:
48
+ img = F.adjust_saturation(img, saturation_factor)
49
+ elif fn_id == 3 and hue_factor is not None:
50
+ img = F.adjust_hue(img, hue_factor)
51
+ return img
52
+
53
+ def forward(self, img1, img2):
54
+
55
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
56
+ self.brightness, self.contrast, self.saturation, self.hue
57
+ )
58
+ img1 = self.jitter_one(img1, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
59
+ if torch.rand(1) < self.assymetric_prob: # assymetric:
60
+ fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = self.get_params(
61
+ self.brightness, self.contrast, self.saturation, self.hue
62
+ )
63
+ img2 = self.jitter_one(img2, fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor)
64
+ return img1, img2
65
+
66
+ def get_pair_transforms(transform_str, totensor=True, normalize=True):
67
+ # transform_str is eg crop224+acolor
68
+ trfs = []
69
+ for s in transform_str.split('+'):
70
+ if s.startswith('crop'):
71
+ size = int(s[len('crop'):])
72
+ trfs.append(RandomCropPair(size))
73
+ elif s=='acolor':
74
+ trfs.append(ColorJitterPair(assymetric_prob=1.0, brightness=(0.6, 1.4), contrast=(0.6, 1.4), saturation=(0.6, 1.4), hue=0.0))
75
+ elif s=='': # if transform_str was ""
76
+ pass
77
+ else:
78
+ raise NotImplementedError('Unknown augmentation: '+s)
79
+
80
+ if totensor:
81
+ trfs.append( ToTensorBoth() )
82
+ if normalize:
83
+ trfs.append( NormalizeBoth(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) )
84
+
85
+ if len(trfs)==0:
86
+ return None
87
+ elif len(trfs)==1:
88
+ return trfs[0]
89
+ else:
90
+ return ComposePair(trfs)
91
+
92
+
93
+
94
+
95
+
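
A quick sanity check of the pair transforms above on a synthetic image pair ('crop64+acolor' is just an example transform string):

import numpy as np
from PIL import Image
from datasets.transforms import get_pair_transforms

trfs = get_pair_transforms('crop64+acolor', totensor=True, normalize=True)
arr = np.random.randint(0, 255, (128, 128, 3), dtype=np.uint8)
t1, t2 = trfs(Image.fromarray(arr), Image.fromarray(arr))
# each output is a normalized 3 x 64 x 64 tensor; the two crops are drawn independently
print(t1.shape, t2.shape)
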
third_party/dust3r/croco/demo.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+ from models.croco import CroCoNet
6
+ from PIL import Image
7
+ import torchvision.transforms
8
+ from torchvision.transforms import ToTensor, Normalize, Compose
9
+
10
+ def main():
11
+ device = torch.device('cuda:0' if torch.cuda.is_available() and torch.cuda.device_count()>0 else 'cpu')
12
+
13
+ # load 224x224 images and transform them to tensor
14
+ imagenet_mean = [0.485, 0.456, 0.406]
15
+ imagenet_mean_tensor = torch.tensor(imagenet_mean).view(1,3,1,1).to(device, non_blocking=True)
16
+ imagenet_std = [0.229, 0.224, 0.225]
17
+ imagenet_std_tensor = torch.tensor(imagenet_std).view(1,3,1,1).to(device, non_blocking=True)
18
+ trfs = Compose([ToTensor(), Normalize(mean=imagenet_mean, std=imagenet_std)])
19
+ image1 = trfs(Image.open('assets/Chateau1.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
20
+ image2 = trfs(Image.open('assets/Chateau2.png').convert('RGB')).to(device, non_blocking=True).unsqueeze(0)
21
+
22
+ # load model
23
+ ckpt = torch.load('pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth', 'cpu')
24
+ model = CroCoNet( **ckpt.get('croco_kwargs',{})).to(device)
25
+ model.eval()
26
+ msg = model.load_state_dict(ckpt['model'], strict=True)
27
+
28
+ # forward
29
+ with torch.inference_mode():
30
+ out, mask, target = model(image1, image2)
31
+
32
+ # the output is normalized, thus use the mean/std of the actual image to go back to RGB space
33
+ patchified = model.patchify(image1)
34
+ mean = patchified.mean(dim=-1, keepdim=True)
35
+ var = patchified.var(dim=-1, keepdim=True)
36
+ decoded_image = model.unpatchify(out * (var + 1.e-6)**.5 + mean)
37
+ # undo imagenet normalization, prepare masked image
38
+ decoded_image = decoded_image * imagenet_std_tensor + imagenet_mean_tensor
39
+ input_image = image1 * imagenet_std_tensor + imagenet_mean_tensor
40
+ ref_image = image2 * imagenet_std_tensor + imagenet_mean_tensor
41
+ image_masks = model.unpatchify(model.patchify(torch.ones_like(ref_image)) * mask[:,:,None])
42
+ masked_input_image = ((1 - image_masks) * input_image)
43
+
44
+ # make visualization
45
+ visualization = torch.cat((ref_image, masked_input_image, decoded_image, input_image), dim=3) # 4*(B, 3, H, W) -> B, 3, H, W*4
46
+ B, C, H, W = visualization.shape
47
+ visualization = visualization.permute(1, 0, 2, 3).reshape(C, B*H, W)
48
+ visualization = torchvision.transforms.functional.to_pil_image(torch.clamp(visualization, 0, 1))
49
+ fname = "demo_output.png"
50
+ visualization.save(fname)
51
+ print('Visualization saved in '+fname)
52
+
53
+
54
+ if __name__=="__main__":
55
+ main()
third_party/dust3r/croco/models/blocks.py ADDED
@@ -0,0 +1,241 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Main encoder/decoder blocks
7
+ # --------------------------------------------------------
8
+ # References:
9
+ # timm
10
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/helpers.py
12
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
13
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/mlp.py
14
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/patch_embed.py
15
+
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from itertools import repeat
21
+ import collections.abc
22
+
23
+
24
+ def _ntuple(n):
25
+ def parse(x):
26
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
27
+ return x
28
+ return tuple(repeat(x, n))
29
+ return parse
30
+ to_2tuple = _ntuple(2)
31
+
32
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
33
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34
+ """
35
+ if drop_prob == 0. or not training:
36
+ return x
37
+ keep_prob = 1 - drop_prob
38
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
39
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
40
+ if keep_prob > 0.0 and scale_by_keep:
41
+ random_tensor.div_(keep_prob)
42
+ return x * random_tensor
43
+
44
+ class DropPath(nn.Module):
45
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
46
+ """
47
+ def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
48
+ super(DropPath, self).__init__()
49
+ self.drop_prob = drop_prob
50
+ self.scale_by_keep = scale_by_keep
51
+
52
+ def forward(self, x):
53
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
54
+
55
+ def extra_repr(self):
56
+ return f'drop_prob={round(self.drop_prob,3):0.3f}'
57
+
58
+ class Mlp(nn.Module):
59
+ """ MLP as used in Vision Transformer, MLP-Mixer and related networks"""
60
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
61
+ super().__init__()
62
+ out_features = out_features or in_features
63
+ hidden_features = hidden_features or in_features
64
+ bias = to_2tuple(bias)
65
+ drop_probs = to_2tuple(drop)
66
+
67
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
68
+ self.act = act_layer()
69
+ self.drop1 = nn.Dropout(drop_probs[0])
70
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
71
+ self.drop2 = nn.Dropout(drop_probs[1])
72
+
73
+ def forward(self, x):
74
+ x = self.fc1(x)
75
+ x = self.act(x)
76
+ x = self.drop1(x)
77
+ x = self.fc2(x)
78
+ x = self.drop2(x)
79
+ return x
80
+
81
+ class Attention(nn.Module):
82
+
83
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
84
+ super().__init__()
85
+ self.num_heads = num_heads
86
+ head_dim = dim // num_heads
87
+ self.scale = head_dim ** -0.5
88
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
89
+ self.attn_drop = nn.Dropout(attn_drop)
90
+ self.proj = nn.Linear(dim, dim)
91
+ self.proj_drop = nn.Dropout(proj_drop)
92
+ self.rope = rope
93
+
94
+ def forward(self, x, xpos):
95
+ B, N, C = x.shape
96
+
97
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).transpose(1,3)
98
+ q, k, v = [qkv[:,:,i] for i in range(3)]
99
+ # q,k,v = qkv.unbind(2) # make torchscript happy (cannot use tensor as tuple)
100
+
101
+ if self.rope is not None:
102
+ q = self.rope(q, xpos)
103
+ k = self.rope(k, xpos)
104
+
105
+ attn = (q @ k.transpose(-2, -1)) * self.scale
106
+ attn = attn.softmax(dim=-1)
107
+ attn = self.attn_drop(attn)
108
+
109
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
110
+ x = self.proj(x)
111
+ x = self.proj_drop(x)
112
+ return x
113
+
114
+ class Block(nn.Module):
115
+
116
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
117
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, rope=None):
118
+ super().__init__()
119
+ self.norm1 = norm_layer(dim)
120
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
121
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
122
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
123
+ self.norm2 = norm_layer(dim)
124
+ mlp_hidden_dim = int(dim * mlp_ratio)
125
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
126
+
127
+ def forward(self, x, xpos):
128
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
129
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
130
+ return x
131
+
132
+ class CrossAttention(nn.Module):
133
+
134
+ def __init__(self, dim, rope=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
135
+ super().__init__()
136
+ self.num_heads = num_heads
137
+ head_dim = dim // num_heads
138
+ self.scale = head_dim ** -0.5
139
+
140
+ self.projq = nn.Linear(dim, dim, bias=qkv_bias)
141
+ self.projk = nn.Linear(dim, dim, bias=qkv_bias)
142
+ self.projv = nn.Linear(dim, dim, bias=qkv_bias)
143
+ self.attn_drop = nn.Dropout(attn_drop)
144
+ self.proj = nn.Linear(dim, dim)
145
+ self.proj_drop = nn.Dropout(proj_drop)
146
+
147
+ self.rope = rope
148
+
149
+ def forward(self, query, key, value, qpos, kpos):
150
+ B, Nq, C = query.shape
151
+ Nk = key.shape[1]
152
+ Nv = value.shape[1]
153
+
154
+ q = self.projq(query).reshape(B,Nq,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
155
+ k = self.projk(key).reshape(B,Nk,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
156
+ v = self.projv(value).reshape(B,Nv,self.num_heads, C// self.num_heads).permute(0, 2, 1, 3)
157
+
158
+ if self.rope is not None:
159
+ q = self.rope(q, qpos)
160
+ k = self.rope(k, kpos)
161
+
162
+ attn = (q @ k.transpose(-2, -1)) * self.scale
163
+ attn = attn.softmax(dim=-1)
164
+ attn = self.attn_drop(attn)
165
+
166
+ x = (attn @ v).transpose(1, 2).reshape(B, Nq, C)
167
+ x = self.proj(x)
168
+ x = self.proj_drop(x)
169
+ return x
170
+
171
+ class DecoderBlock(nn.Module):
172
+
173
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
174
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, norm_mem=True, rope=None):
175
+ super().__init__()
176
+ self.norm1 = norm_layer(dim)
177
+ self.attn = Attention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
178
+ self.cross_attn = CrossAttention(dim, rope=rope, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
179
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
180
+ self.norm2 = norm_layer(dim)
181
+ self.norm3 = norm_layer(dim)
182
+ mlp_hidden_dim = int(dim * mlp_ratio)
183
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
184
+ self.norm_y = norm_layer(dim) if norm_mem else nn.Identity()
185
+
186
+ def forward(self, x, y, xpos, ypos):
187
+ x = x + self.drop_path(self.attn(self.norm1(x), xpos))
188
+ y_ = self.norm_y(y)
189
+ x = x + self.drop_path(self.cross_attn(self.norm2(x), y_, y_, xpos, ypos))
190
+ x = x + self.drop_path(self.mlp(self.norm3(x)))
191
+ return x, y
192
+
193
+
194
+ # patch embedding
195
+ class PositionGetter(object):
196
+ """ return positions of patches """
197
+
198
+ def __init__(self):
199
+ self.cache_positions = {}
200
+
201
+ def __call__(self, b, h, w, device):
202
+ if not (h,w) in self.cache_positions:
203
+ x = torch.arange(w, device=device)
204
+ y = torch.arange(h, device=device)
205
+ self.cache_positions[h,w] = torch.cartesian_prod(y, x) # (h, w, 2)
206
+ pos = self.cache_positions[h,w].view(1, h*w, 2).expand(b, -1, 2).clone()
207
+ return pos
208
+
209
+ class PatchEmbed(nn.Module):
210
+ """ just adding _init_weights + position getter compared to timm.models.layers.patch_embed.PatchEmbed"""
211
+
212
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
213
+ super().__init__()
214
+ img_size = to_2tuple(img_size)
215
+ patch_size = to_2tuple(patch_size)
216
+ self.img_size = img_size
217
+ self.patch_size = patch_size
218
+ self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
219
+ self.num_patches = self.grid_size[0] * self.grid_size[1]
220
+ self.flatten = flatten
221
+
222
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
223
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
224
+
225
+ self.position_getter = PositionGetter()
226
+
227
+ def forward(self, x):
228
+ B, C, H, W = x.shape
229
+ torch._assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).")
230
+ torch._assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).")
231
+ x = self.proj(x)
232
+ pos = self.position_getter(B, x.size(2), x.size(3), x.device)
233
+ if self.flatten:
234
+ x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
235
+ x = self.norm(x)
236
+ return x, pos
237
+
238
+ def _init_weights(self):
239
+ w = self.proj.weight.data
240
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
241
+
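
A shape walkthrough for the blocks above in the cosine positional-embedding case (rope=None), with a 14x14 grid of tokens as produced by a 224/16 patch embedding; the sketch assumes the croco root is on PYTHONPATH:

import torch
from models.blocks import Block, PositionGetter

dim, num_heads, b, h, w = 768, 12, 2, 14, 14
block = Block(dim, num_heads, qkv_bias=True, rope=None)
x = torch.randn(b, h * w, dim)
pos = PositionGetter()(b, h, w, x.device)  # (b, h*w, 2) integer (y, x) grid positions
out = block(x, pos)
print(out.shape)  # torch.Size([2, 196, 768])
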
third_party/dust3r/croco/models/criterion.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Criterion to train CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # --------------------------------------------------------
10
+
11
+ import torch
12
+
13
+ class MaskedMSE(torch.nn.Module):
14
+
15
+ def __init__(self, norm_pix_loss=False, masked=True):
16
+ """
17
+ norm_pix_loss: normalize each patch by their pixel mean and variance
18
+ masked: compute loss over the masked patches only
19
+ """
20
+ super().__init__()
21
+ self.norm_pix_loss = norm_pix_loss
22
+ self.masked = masked
23
+
24
+ def forward(self, pred, mask, target):
25
+
26
+ if self.norm_pix_loss:
27
+ mean = target.mean(dim=-1, keepdim=True)
28
+ var = target.var(dim=-1, keepdim=True)
29
+ target = (target - mean) / (var + 1.e-6)**.5
30
+
31
+ loss = (pred - target) ** 2
32
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
33
+ if self.masked:
34
+ loss = (loss * mask).sum() / mask.sum() # mean loss on masked patches
35
+ else:
36
+ loss = loss.mean() # mean loss
37
+ return loss
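
A toy invocation of MaskedMSE with tensors shaped like the CroCo pretraining outputs (196 patch tokens of 16*16*3 = 768 values; the mask is 1 on reconstructed patches):

import torch
from models.criterion import MaskedMSE

criterion = MaskedMSE(norm_pix_loss=True, masked=True)
B, N, D = 2, 196, 768
pred = torch.randn(B, N, D)
target = torch.randn(B, N, D)
mask = (torch.rand(B, N) > 0.1).float()  # 1 for masked patches, 0 for visible ones
loss = criterion(pred, mask, target)
print(loss.item())
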
third_party/dust3r/croco/models/croco.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # CroCo model during pretraining
7
+ # --------------------------------------------------------
8
+
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ torch.backends.cuda.matmul.allow_tf32 = True # for gpu >= Ampere and pytorch >= 1.12
14
+ from functools import partial
15
+
16
+ from models.blocks import Block, DecoderBlock, PatchEmbed
17
+ from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D
18
+ from models.masking import RandomMask
19
+
20
+
21
+ class CroCoNet(nn.Module):
22
+
23
+ def __init__(self,
24
+ img_size=224, # input image size
25
+ patch_size=16, # patch_size
26
+ mask_ratio=0.9, # ratios of masked tokens
27
+ enc_embed_dim=768, # encoder feature dimension
28
+ enc_depth=12, # encoder depth
29
+ enc_num_heads=12, # encoder number of heads in the transformer block
30
+ dec_embed_dim=512, # decoder feature dimension
31
+ dec_depth=8, # decoder depth
32
+ dec_num_heads=16, # decoder number of heads in the transformer block
33
+ mlp_ratio=4,
34
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
35
+ norm_im2_in_dec=True, # whether to apply normalization of the 'memory' = (second image) in the decoder
36
+ pos_embed='cosine', # positional embedding (either cosine or RoPE100)
37
+ ):
38
+
39
+ super(CroCoNet, self).__init__()
40
+
41
+ # patch embeddings (with initialization done as in MAE)
42
+ self._set_patch_embed(img_size, patch_size, enc_embed_dim)
43
+
44
+ # mask generations
45
+ self._set_mask_generator(self.patch_embed.num_patches, mask_ratio)
46
+
47
+ self.pos_embed = pos_embed
48
+ if pos_embed=='cosine':
49
+ # positional embedding of the encoder
50
+ enc_pos_embed = get_2d_sincos_pos_embed(enc_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
51
+ self.register_buffer('enc_pos_embed', torch.from_numpy(enc_pos_embed).float())
52
+ # positional embedding of the decoder
53
+ dec_pos_embed = get_2d_sincos_pos_embed(dec_embed_dim, int(self.patch_embed.num_patches**.5), n_cls_token=0)
54
+ self.register_buffer('dec_pos_embed', torch.from_numpy(dec_pos_embed).float())
55
+ # pos embedding in each block
56
+ self.rope = None # nothing for cosine
57
+ elif pos_embed.startswith('RoPE'): # eg RoPE100
58
+ self.enc_pos_embed = None # nothing to add in the encoder with RoPE
59
+ self.dec_pos_embed = None # nothing to add in the decoder with RoPE
60
+ if RoPE2D is None: raise ImportError("Cannot find cuRoPE2D, please install it following the README instructions")
61
+ freq = float(pos_embed[len('RoPE'):])
62
+ self.rope = RoPE2D(freq=freq)
63
+ else:
64
+ raise NotImplementedError('Unknown pos_embed '+pos_embed)
65
+
66
+ # transformer for the encoder
67
+ self.enc_depth = enc_depth
68
+ self.enc_embed_dim = enc_embed_dim
69
+ self.enc_blocks = nn.ModuleList([
70
+ Block(enc_embed_dim, enc_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, rope=self.rope)
71
+ for i in range(enc_depth)])
72
+ self.enc_norm = norm_layer(enc_embed_dim)
73
+
74
+ # masked tokens
75
+ self._set_mask_token(dec_embed_dim)
76
+
77
+ # decoder
78
+ self._set_decoder(enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec)
79
+
80
+ # prediction head
81
+ self._set_prediction_head(dec_embed_dim, patch_size)
82
+
83
+ # initialize weights
84
+ self.initialize_weights()
85
+
86
+ def _set_patch_embed(self, img_size=224, patch_size=16, enc_embed_dim=768):
87
+ self.patch_embed = PatchEmbed(img_size, patch_size, 3, enc_embed_dim)
88
+
89
+ def _set_mask_generator(self, num_patches, mask_ratio):
90
+ self.mask_generator = RandomMask(num_patches, mask_ratio)
91
+
92
+ def _set_mask_token(self, dec_embed_dim):
93
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, dec_embed_dim))
94
+
95
+ def _set_decoder(self, enc_embed_dim, dec_embed_dim, dec_num_heads, dec_depth, mlp_ratio, norm_layer, norm_im2_in_dec):
96
+ self.dec_depth = dec_depth
97
+ self.dec_embed_dim = dec_embed_dim
98
+ # transfer from encoder to decoder
99
+ self.decoder_embed = nn.Linear(enc_embed_dim, dec_embed_dim, bias=True)
100
+ # transformer for the decoder
101
+ self.dec_blocks = nn.ModuleList([
102
+ DecoderBlock(dec_embed_dim, dec_num_heads, mlp_ratio=mlp_ratio, qkv_bias=True, norm_layer=norm_layer, norm_mem=norm_im2_in_dec, rope=self.rope)
103
+ for i in range(dec_depth)])
104
+ # final norm layer
105
+ self.dec_norm = norm_layer(dec_embed_dim)
106
+
107
+ def _set_prediction_head(self, dec_embed_dim, patch_size):
108
+ self.prediction_head = nn.Linear(dec_embed_dim, patch_size**2 * 3, bias=True)
109
+
110
+
111
+ def initialize_weights(self):
112
+ # patch embed
113
+ self.patch_embed._init_weights()
114
+ # mask tokens
115
+ if self.mask_token is not None: torch.nn.init.normal_(self.mask_token, std=.02)
116
+ # linears and layer norms
117
+ self.apply(self._init_weights)
118
+
119
+ def _init_weights(self, m):
120
+ if isinstance(m, nn.Linear):
121
+ # we use xavier_uniform following official JAX ViT:
122
+ torch.nn.init.xavier_uniform_(m.weight)
123
+ if isinstance(m, nn.Linear) and m.bias is not None:
124
+ nn.init.constant_(m.bias, 0)
125
+ elif isinstance(m, nn.LayerNorm):
126
+ nn.init.constant_(m.bias, 0)
127
+ nn.init.constant_(m.weight, 1.0)
128
+
129
+ def _encode_image(self, image, do_mask=False, return_all_blocks=False):
130
+ """
131
+ image has B x 3 x img_size x img_size
132
+ do_mask: whether to perform masking or not
133
+ return_all_blocks: if True, return the features at the end of every block
134
+ instead of just the features from the last block (eg for some prediction heads)
135
+ """
136
+ # embed the image into patches (x has size B x Npatches x C)
137
+ # and get the position of each returned patch (pos has size B x Npatches x 2)
138
+ x, pos = self.patch_embed(image)
139
+ # add positional embedding without cls token
140
+ if self.enc_pos_embed is not None:
141
+ x = x + self.enc_pos_embed[None,...]
142
+ # apply masking
143
+ B,N,C = x.size()
144
+ if do_mask:
145
+ masks = self.mask_generator(x)
146
+ x = x[~masks].view(B, -1, C)
147
+ posvis = pos[~masks].view(B, -1, 2)
148
+ else:
149
+ B,N,C = x.size()
150
+ masks = torch.zeros((B,N), dtype=bool)
151
+ posvis = pos
152
+ # now apply the transformer encoder and normalization
153
+ if return_all_blocks:
154
+ out = []
155
+ for blk in self.enc_blocks:
156
+ x = blk(x, posvis)
157
+ out.append(x)
158
+ out[-1] = self.enc_norm(out[-1])
159
+ return out, pos, masks
160
+ else:
161
+ for blk in self.enc_blocks:
162
+ x = blk(x, posvis)
163
+ x = self.enc_norm(x)
164
+ return x, pos, masks
165
+
166
+ def _decoder(self, feat1, pos1, masks1, feat2, pos2, return_all_blocks=False):
167
+ """
168
+ return_all_blocks: if True, return the features at the end of every block
169
+ instead of just the features from the last block (eg for some prediction heads)
170
+
171
+ masks1 can be None => assume image1 fully visible
172
+ """
173
+ # encoder to decoder layer
174
+ visf1 = self.decoder_embed(feat1)
175
+ f2 = self.decoder_embed(feat2)
176
+ # append masked tokens to the sequence
177
+ B,Nenc,C = visf1.size()
178
+ if masks1 is None: # downstreams
179
+ f1_ = visf1
180
+ else: # pretraining
181
+ Ntotal = masks1.size(1)
182
+ f1_ = self.mask_token.repeat(B, Ntotal, 1).to(dtype=visf1.dtype)
183
+ f1_[~masks1] = visf1.view(B * Nenc, C)
184
+ # add positional embedding
185
+ if self.dec_pos_embed is not None:
186
+ f1_ = f1_ + self.dec_pos_embed
187
+ f2 = f2 + self.dec_pos_embed
188
+ # apply Transformer blocks
189
+ out = f1_
190
+ out2 = f2
191
+ if return_all_blocks:
192
+ _out, out = out, []
193
+ for blk in self.dec_blocks:
194
+ _out, out2 = blk(_out, out2, pos1, pos2)
195
+ out.append(_out)
196
+ out[-1] = self.dec_norm(out[-1])
197
+ else:
198
+ for blk in self.dec_blocks:
199
+ out, out2 = blk(out, out2, pos1, pos2)
200
+ out = self.dec_norm(out)
201
+ return out
202
+
203
+ def patchify(self, imgs):
204
+ """
205
+ imgs: (B, 3, H, W)
206
+ x: (B, L, patch_size**2 *3)
207
+ """
208
+ p = self.patch_embed.patch_size[0]
209
+ assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
210
+
211
+ h = w = imgs.shape[2] // p
212
+ x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
213
+ x = torch.einsum('nchpwq->nhwpqc', x)
214
+ x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
215
+
216
+ return x
217
+
218
+ def unpatchify(self, x, channels=3):
219
+ """
220
+ x: (N, L, patch_size**2 *channels)
221
+ imgs: (N, 3, H, W)
222
+ """
223
+ patch_size = self.patch_embed.patch_size[0]
224
+ h = w = int(x.shape[1]**.5)
225
+ assert h * w == x.shape[1]
226
+ x = x.reshape(shape=(x.shape[0], h, w, patch_size, patch_size, channels))
227
+ x = torch.einsum('nhwpqc->nchpwq', x)
228
+ imgs = x.reshape(shape=(x.shape[0], channels, h * patch_size, h * patch_size))
229
+ return imgs
230
+
231
+ def forward(self, img1, img2):
232
+ """
233
+ img1: tensor of size B x 3 x img_size x img_size
234
+ img2: tensor of size B x 3 x img_size x img_size
235
+
236
+ out will be B x N x (3*patch_size*patch_size)
237
+ masks are also returned as B x N just in case
238
+ """
239
+ # encoder of the masked first image
240
+ feat1, pos1, mask1 = self._encode_image(img1, do_mask=True)
241
+ # encoder of the second image
242
+ feat2, pos2, _ = self._encode_image(img2, do_mask=False)
243
+ # decoder
244
+ decfeat = self._decoder(feat1, pos1, mask1, feat2, pos2)
245
+ # prediction head
246
+ out = self.prediction_head(decfeat)
247
+ # get target
248
+ target = self.patchify(img1)
249
+ return out, mask1, target
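
A random-weight forward pass through CroCoNet (no checkpoint required) to inspect the shapes returned by the pretraining forward above:

import torch
from models.croco import CroCoNet

model = CroCoNet().eval()  # default configuration: 224x224 inputs, 16x16 patches
img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
with torch.inference_mode():
    out, mask, target = model(img1, img2)
# out and target: (1, 196, 768); mask: (1, 196) with ~90% of the patches masked
print(out.shape, mask.shape, target.shape)
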
third_party/dust3r/croco/models/croco_downstream.py ADDED
@@ -0,0 +1,122 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # CroCo model for downstream tasks
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+
10
+ from .croco import CroCoNet
11
+
12
+
13
+ def croco_args_from_ckpt(ckpt):
14
+ if 'croco_kwargs' in ckpt: # CroCo v2 released models
15
+ return ckpt['croco_kwargs']
16
+ elif 'args' in ckpt and hasattr(ckpt['args'], 'model'): # pretrained using the official code release
17
+ s = ckpt['args'].model # eg "CroCoNet(enc_embed_dim=1024, enc_num_heads=16, enc_depth=24)"
18
+ assert s.startswith('CroCoNet(')
19
+ return eval('dict'+s[len('CroCoNet'):]) # transform it into the string of a dictionary and evaluate it
20
+ else: # CroCo v1 released models
21
+ return dict()
22
+
23
+ class CroCoDownstreamMonocularEncoder(CroCoNet):
24
+
25
+ def __init__(self,
26
+ head,
27
+ **kwargs):
28
+ """ Build network for monocular downstream task, only using the encoder.
29
+ It takes an extra argument head, that is called with the features
30
+ and a dictionary img_info containing 'width' and 'height' keys
31
+ The head is setup with the croconet arguments in this init function
32
+ NOTE: it works by calling super().__init__() but with redefined setters
33
+
34
+ """
35
+ super(CroCoDownstreamMonocularEncoder, self).__init__(**kwargs)
36
+ head.setup(self)
37
+ self.head = head
38
+
39
+ def _set_mask_generator(self, *args, **kwargs):
40
+ """ No mask generator """
41
+ return
42
+
43
+ def _set_mask_token(self, *args, **kwargs):
44
+ """ No mask token """
45
+ self.mask_token = None
46
+ return
47
+
48
+ def _set_decoder(self, *args, **kwargs):
49
+ """ No decoder """
50
+ return
51
+
52
+ def _set_prediction_head(self, *args, **kwargs):
53
+ """ No 'prediction head' for downstream tasks."""
54
+ return
55
+
56
+ def forward(self, img):
57
+ """
58
+ img is of size batch_size x 3 x h x w
59
+ """
60
+ B, C, H, W = img.size()
61
+ img_info = {'height': H, 'width': W}
62
+ need_all_layers = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
63
+ out, _, _ = self._encode_image(img, do_mask=False, return_all_blocks=need_all_layers)
64
+ return self.head(out, img_info)
65
+
66
+
67
+ class CroCoDownstreamBinocular(CroCoNet):
68
+
69
+ def __init__(self,
70
+ head,
71
+ **kwargs):
72
+ """ Build network for binocular downstream task
73
+ It takes an extra argument head, that is called with the features
74
+ and a dictionary img_info containing 'width' and 'height' keys
75
+ The head is setup with the croconet arguments in this init function
76
+ """
77
+ super(CroCoDownstreamBinocular, self).__init__(**kwargs)
78
+ head.setup(self)
79
+ self.head = head
80
+
81
+ def _set_mask_generator(self, *args, **kwargs):
82
+ """ No mask generator """
83
+ return
84
+
85
+ def _set_mask_token(self, *args, **kwargs):
86
+ """ No mask token """
87
+ self.mask_token = None
88
+ return
89
+
90
+ def _set_prediction_head(self, *args, **kwargs):
91
+ """ No prediction head for downstream tasks, define your own head """
92
+ return
93
+
94
+ def encode_image_pairs(self, img1, img2, return_all_blocks=False):
95
+ """ run encoder for a pair of images
96
+ it is actually ~5% faster to concatenate the images along the batch dimension
97
+ than to encode them separately
98
+ """
99
+ ## the two commented lines below are the naive version with separate encoding
100
+ #out, pos, _ = self._encode_image(img1, do_mask=False, return_all_blocks=return_all_blocks)
101
+ #out2, pos2, _ = self._encode_image(img2, do_mask=False, return_all_blocks=False)
102
+ ## and now the faster version
103
+ out, pos, _ = self._encode_image( torch.cat( (img1,img2), dim=0), do_mask=False, return_all_blocks=return_all_blocks )
104
+ if return_all_blocks:
105
+ out,out2 = list(map(list, zip(*[o.chunk(2, dim=0) for o in out])))
106
+ out2 = out2[-1]
107
+ else:
108
+ out,out2 = out.chunk(2, dim=0)
109
+ pos,pos2 = pos.chunk(2, dim=0)
110
+ return out, out2, pos, pos2
111
+
112
+ def forward(self, img1, img2):
113
+ B, C, H, W = img1.size()
114
+ img_info = {'height': H, 'width': W}
115
+ return_all_blocks = hasattr(self.head, 'return_all_blocks') and self.head.return_all_blocks
116
+ out, out2, pos, pos2 = self.encode_image_pairs(img1, img2, return_all_blocks=return_all_blocks)
117
+ if return_all_blocks:
118
+ decout = self._decoder(out[-1], pos, None, out2, pos2, return_all_blocks=return_all_blocks)
119
+ decout = out+decout
120
+ else:
121
+ decout = self._decoder(out, pos, None, out2, pos2, return_all_blocks=return_all_blocks)
122
+ return self.head(decout, img_info)
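
The downstream wrappers above only assume that the head exposes setup(croconet) and is callable with (features, img_info). A minimal placeholder head (purely illustrative, not part of the release) makes the expected interface concrete:

import torch
from models.croco_downstream import CroCoDownstreamBinocular

class IdentityHead(torch.nn.Module):
    def setup(self, croconet):
        # remember the decoder width chosen by the backbone
        self.dec_embed_dim = croconet.dec_embed_dim
    def forward(self, decout, img_info):
        return decout  # (B, num_patches, dec_embed_dim) decoder features

model = CroCoDownstreamBinocular(head=IdentityHead())
img1 = torch.randn(1, 3, 224, 224)
img2 = torch.randn(1, 3, 224, 224)
feats = model(img1, img2)
print(feats.shape)  # torch.Size([1, 196, 512]) with the default decoder width
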
third_party/dust3r/croco/models/curope/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ from .curope2d import cuRoPE2D
third_party/dust3r/croco/models/curope/curope.cpp ADDED
@@ -0,0 +1,69 @@
1
+ /*
2
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+
8
+ // forward declaration
9
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd );
10
+
11
+ void rope_2d_cpu( torch::Tensor tokens, const torch::Tensor positions, const float base, const float fwd )
12
+ {
13
+ const int B = tokens.size(0);
14
+ const int N = tokens.size(1);
15
+ const int H = tokens.size(2);
16
+ const int D = tokens.size(3) / 4;
17
+
18
+ auto tok = tokens.accessor<float, 4>();
19
+ auto pos = positions.accessor<int64_t, 3>();
20
+
21
+ for (int b = 0; b < B; b++) {
22
+ for (int x = 0; x < 2; x++) { // y and then x (2d)
23
+ for (int n = 0; n < N; n++) {
24
+
25
+ // grab the token position
26
+ const int p = pos[b][n][x];
27
+
28
+ for (int h = 0; h < H; h++) {
29
+ for (int d = 0; d < D; d++) {
30
+ // grab the two values
31
+ float u = tok[b][n][h][d+0+x*2*D];
32
+ float v = tok[b][n][h][d+D+x*2*D];
33
+
34
+ // grab the cos,sin
35
+ const float inv_freq = fwd * p / powf(base, d/float(D));
36
+ float c = cosf(inv_freq);
37
+ float s = sinf(inv_freq);
38
+
39
+ // write the result
40
+ tok[b][n][h][d+0+x*2*D] = u*c - v*s;
41
+ tok[b][n][h][d+D+x*2*D] = v*c + u*s;
42
+ }
43
+ }
44
+ }
45
+ }
46
+ }
47
+ }
48
+
49
+ void rope_2d( torch::Tensor tokens, // B,N,H,D
50
+ const torch::Tensor positions, // B,N,2
51
+ const float base,
52
+ const float fwd )
53
+ {
54
+ TORCH_CHECK(tokens.dim() == 4, "tokens must have 4 dimensions");
55
+ TORCH_CHECK(positions.dim() == 3, "positions must have 3 dimensions");
56
+ TORCH_CHECK(tokens.size(0) == positions.size(0), "batch size differs between tokens & positions");
57
+ TORCH_CHECK(tokens.size(1) == positions.size(1), "seq_length differs between tokens & positions");
58
+ TORCH_CHECK(positions.size(2) == 2, "positions.shape[2] must be equal to 2");
59
+ TORCH_CHECK(tokens.is_cuda() == positions.is_cuda(), "tokens and positions are not on the same device" );
60
+
61
+ if (tokens.is_cuda())
62
+ rope_2d_cuda( tokens, positions, base, fwd );
63
+ else
64
+ rope_2d_cpu( tokens, positions, base, fwd );
65
+ }
66
+
67
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
68
+ m.def("rope_2d", &rope_2d, "RoPE 2d forward/backward");
69
+ }
third_party/dust3r/croco/models/curope/curope2d.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ import torch
5
+
6
+ try:
7
+ import curope as _kernels # run `python setup.py install`
8
+ except ModuleNotFoundError:
9
+ from . import curope as _kernels # run `python setup.py build_ext --inplace`
10
+
11
+
12
+ class cuRoPE2D_func (torch.autograd.Function):
13
+
14
+ @staticmethod
15
+ def forward(ctx, tokens, positions, base, F0=1):
16
+ ctx.save_for_backward(positions)
17
+ ctx.saved_base = base
18
+ ctx.saved_F0 = F0
19
+ # tokens = tokens.clone() # uncomment this if inplace doesn't work
20
+ _kernels.rope_2d( tokens, positions, base, F0 )
21
+ ctx.mark_dirty(tokens)
22
+ return tokens
23
+
24
+ @staticmethod
25
+ def backward(ctx, grad_res):
26
+ positions, base, F0 = ctx.saved_tensors[0], ctx.saved_base, ctx.saved_F0
27
+ _kernels.rope_2d( grad_res, positions, base, -F0 )
28
+ ctx.mark_dirty(grad_res)
29
+ return grad_res, None, None, None
30
+
31
+
32
+ class cuRoPE2D(torch.nn.Module):
33
+ def __init__(self, freq=100.0, F0=1.0):
34
+ super().__init__()
35
+ self.base = freq
36
+ self.F0 = F0
37
+
38
+ def forward(self, tokens, positions):
39
+ cuRoPE2D_func.apply( tokens.transpose(1,2), positions, self.base, self.F0 )
40
+ return tokens
third_party/dust3r/croco/models/curope/kernels.cu ADDED
@@ -0,0 +1,108 @@
1
+ /*
2
+ Copyright (C) 2022-present Naver Corporation. All rights reserved.
3
+ Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
4
+ */
5
+
6
+ #include <torch/extension.h>
7
+ #include <cuda.h>
8
+ #include <cuda_runtime.h>
9
+ #include <vector>
10
+
11
+ #define CHECK_CUDA(tensor) {\
12
+ TORCH_CHECK((tensor).is_cuda(), #tensor " is not in cuda memory"); \
13
+ TORCH_CHECK((tensor).is_contiguous(), #tensor " is not contiguous"); }
14
+ void CHECK_KERNEL() {auto error = cudaGetLastError(); TORCH_CHECK( error == cudaSuccess, cudaGetErrorString(error));}
15
+
16
+
17
+ template < typename scalar_t >
18
+ __global__ void rope_2d_cuda_kernel(
19
+ //scalar_t* __restrict__ tokens,
20
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> tokens,
21
+ const int64_t* __restrict__ pos,
22
+ const float base,
23
+ const float fwd )
24
+ // const int N, const int H, const int D )
25
+ {
26
+ // tokens shape = (B, N, H, D)
27
+ const int N = tokens.size(1);
28
+ const int H = tokens.size(2);
29
+ const int D = tokens.size(3);
30
+
31
+ // each block updates a single token, for all heads
32
+ // each thread takes care of a single output
33
+ extern __shared__ float shared[];
34
+ float* shared_inv_freq = shared + D;
35
+
36
+ const int b = blockIdx.x / N;
37
+ const int n = blockIdx.x % N;
38
+
39
+ const int Q = D / 4;
40
+ // one token = [0..Q : Q..2Q : 2Q..3Q : 3Q..D]
41
+ // u_Y v_Y u_X v_X
42
+
43
+ // shared memory: first, compute inv_freq
44
+ if (threadIdx.x < Q)
45
+ shared_inv_freq[threadIdx.x] = fwd / powf(base, threadIdx.x/float(Q));
46
+ __syncthreads();
47
+
48
+ // start of X or Y part
49
+ const int X = threadIdx.x < D/2 ? 0 : 1;
50
+ const int m = (X*D/2) + (threadIdx.x % Q); // index of u_Y or u_X
51
+
52
+ // grab the cos,sin appropriate for me
53
+ const float freq = pos[blockIdx.x*2+X] * shared_inv_freq[threadIdx.x % Q];
54
+ const float cos = cosf(freq);
55
+ const float sin = sinf(freq);
56
+ /*
57
+ float* shared_cos_sin = shared + D + D/4;
58
+ if ((threadIdx.x % (D/2)) < Q)
59
+ shared_cos_sin[m+0] = cosf(freq);
60
+ else
61
+ shared_cos_sin[m+Q] = sinf(freq);
62
+ __syncthreads();
63
+ const float cos = shared_cos_sin[m+0];
64
+ const float sin = shared_cos_sin[m+Q];
65
+ */
66
+
67
+ for (int h = 0; h < H; h++)
68
+ {
69
+ // then, load all the token for this head in shared memory
70
+ shared[threadIdx.x] = tokens[b][n][h][threadIdx.x];
71
+ __syncthreads();
72
+
73
+ const float u = shared[m];
74
+ const float v = shared[m+Q];
75
+
76
+ // write output
77
+ if ((threadIdx.x % (D/2)) < Q)
78
+ tokens[b][n][h][threadIdx.x] = u*cos - v*sin;
79
+ else
80
+ tokens[b][n][h][threadIdx.x] = v*cos + u*sin;
81
+ }
82
+ }
83
+
84
+ void rope_2d_cuda( torch::Tensor tokens, const torch::Tensor pos, const float base, const float fwd )
85
+ {
86
+ const int B = tokens.size(0); // batch size
87
+ const int N = tokens.size(1); // sequence length
88
+ const int H = tokens.size(2); // number of heads
89
+ const int D = tokens.size(3); // dimension per head
90
+
91
+ TORCH_CHECK(tokens.stride(3) == 1 && tokens.stride(2) == D, "tokens are not contiguous");
92
+ TORCH_CHECK(pos.is_contiguous(), "positions are not contiguous");
93
+ TORCH_CHECK(pos.size(0) == B && pos.size(1) == N && pos.size(2) == 2, "bad pos.shape");
94
+ TORCH_CHECK(D % 4 == 0, "token dim must be multiple of 4");
95
+
96
+ // one block per (batch, token) pair, one thread per channel
97
+ const int THREADS_PER_BLOCK = D;
98
+ const int N_BLOCKS = B * N; // each block takes care of H*D values
99
+ const int SHARED_MEM = sizeof(float) * (D + D/4);
100
+
101
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(tokens.type(), "rope_2d_cuda", ([&] {
102
+ rope_2d_cuda_kernel<scalar_t> <<<N_BLOCKS, THREADS_PER_BLOCK, SHARED_MEM>>> (
103
+ //tokens.data_ptr<scalar_t>(),
104
+ tokens.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
105
+ pos.data_ptr<int64_t>(),
106
+ base, fwd); //, N, H, D );
107
+ }));
108
+ }
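For readers checking the `[u_Y v_Y u_X v_X]` layout used by the kernel above, here is a rough pure-PyTorch rendition of the same rotation. It is an illustrative reference only, not part of the repository, and assumes the `(B, N, H, D)` token layout and the `positions[..., 0] = y`, `positions[..., 1] = x` convention:
```
import torch

def rope_2d_reference(tokens, positions, base=100.0, fwd=1.0):
    """Out-of-place sketch of what rope_2d_cuda_kernel computes, for tokens of shape (B, N, H, D)."""
    B, N, H, D = tokens.shape
    Q = D // 4
    inv_freq = fwd / base ** (torch.arange(Q, dtype=torch.float32) / Q)     # matches shared_inv_freq
    out = tokens.clone()
    for half, coord in ((0, 0), (1, 1)):                                     # Y half, then X half
        offs = half * (D // 2)
        u = tokens[..., offs:offs + Q]                                       # u_Y or u_X
        v = tokens[..., offs + Q:offs + 2 * Q]                               # v_Y or v_X
        ang = positions[..., coord].unsqueeze(-1).float() * inv_freq         # (B, N, Q)
        cos, sin = ang.cos().unsqueeze(2), ang.sin().unsqueeze(2)            # broadcast over heads
        out[..., offs:offs + Q] = u * cos - v * sin
        out[..., offs + Q:offs + 2 * Q] = v * cos + u * sin
    return out
```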
third_party/dust3r/croco/models/curope/setup.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ from setuptools import setup
5
+ from torch import cuda
6
+ from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7
+
8
+ # compile for all possible CUDA architectures
9
+ all_cuda_archs = cuda.get_gencode_flags().replace('compute=','arch=').split()
10
+ # alternatively, you can list cuda archs that you want, eg:
11
+ # all_cuda_archs = [
12
+ # '-gencode', 'arch=compute_70,code=sm_70',
13
+ # '-gencode', 'arch=compute_75,code=sm_75',
14
+ # '-gencode', 'arch=compute_80,code=sm_80',
15
+ # '-gencode', 'arch=compute_86,code=sm_86'
16
+ # ]
17
+
18
+ setup(
19
+ name = 'curope',
20
+ ext_modules = [
21
+ CUDAExtension(
22
+ name='curope',
23
+ sources=[
24
+ "curope.cpp",
25
+ "kernels.cu",
26
+ ],
27
+ extra_compile_args = dict(
28
+ nvcc=['-O3','--ptxas-options=-v',"--use_fast_math"]+all_cuda_archs,
29
+ cxx=['-O3'])
30
+ )
31
+ ],
32
+ cmdclass = {
33
+ 'build_ext': BuildExtension
34
+ })
third_party/dust3r/croco/models/dpt_block.py ADDED
@@ -0,0 +1,450 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # DPT head for ViTs
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # https://github.com/isl-org/DPT
9
+ # https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/output_adapters.py
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ from einops import rearrange, repeat
15
+ from typing import Union, Tuple, Iterable, List, Optional, Dict
16
+
17
+ def pair(t):
18
+ return t if isinstance(t, tuple) else (t, t)
19
+
20
+ def make_scratch(in_shape, out_shape, groups=1, expand=False):
21
+ scratch = nn.Module()
22
+
23
+ out_shape1 = out_shape
24
+ out_shape2 = out_shape
25
+ out_shape3 = out_shape
26
+ out_shape4 = out_shape
27
+ if expand == True:
28
+ out_shape1 = out_shape
29
+ out_shape2 = out_shape * 2
30
+ out_shape3 = out_shape * 4
31
+ out_shape4 = out_shape * 8
32
+
33
+ scratch.layer1_rn = nn.Conv2d(
34
+ in_shape[0],
35
+ out_shape1,
36
+ kernel_size=3,
37
+ stride=1,
38
+ padding=1,
39
+ bias=False,
40
+ groups=groups,
41
+ )
42
+ scratch.layer2_rn = nn.Conv2d(
43
+ in_shape[1],
44
+ out_shape2,
45
+ kernel_size=3,
46
+ stride=1,
47
+ padding=1,
48
+ bias=False,
49
+ groups=groups,
50
+ )
51
+ scratch.layer3_rn = nn.Conv2d(
52
+ in_shape[2],
53
+ out_shape3,
54
+ kernel_size=3,
55
+ stride=1,
56
+ padding=1,
57
+ bias=False,
58
+ groups=groups,
59
+ )
60
+ scratch.layer4_rn = nn.Conv2d(
61
+ in_shape[3],
62
+ out_shape4,
63
+ kernel_size=3,
64
+ stride=1,
65
+ padding=1,
66
+ bias=False,
67
+ groups=groups,
68
+ )
69
+
70
+ scratch.layer_rn = nn.ModuleList([
71
+ scratch.layer1_rn,
72
+ scratch.layer2_rn,
73
+ scratch.layer3_rn,
74
+ scratch.layer4_rn,
75
+ ])
76
+
77
+ return scratch
78
+
79
+ class ResidualConvUnit_custom(nn.Module):
80
+ """Residual convolution module."""
81
+
82
+ def __init__(self, features, activation, bn):
83
+ """Init.
84
+ Args:
85
+ features (int): number of features
86
+ """
87
+ super().__init__()
88
+
89
+ self.bn = bn
90
+
91
+ self.groups = 1
92
+
93
+ self.conv1 = nn.Conv2d(
94
+ features,
95
+ features,
96
+ kernel_size=3,
97
+ stride=1,
98
+ padding=1,
99
+ bias=not self.bn,
100
+ groups=self.groups,
101
+ )
102
+
103
+ self.conv2 = nn.Conv2d(
104
+ features,
105
+ features,
106
+ kernel_size=3,
107
+ stride=1,
108
+ padding=1,
109
+ bias=not self.bn,
110
+ groups=self.groups,
111
+ )
112
+
113
+ if self.bn == True:
114
+ self.bn1 = nn.BatchNorm2d(features)
115
+ self.bn2 = nn.BatchNorm2d(features)
116
+
117
+ self.activation = activation
118
+
119
+ self.skip_add = nn.quantized.FloatFunctional()
120
+
121
+ def forward(self, x):
122
+ """Forward pass.
123
+ Args:
124
+ x (tensor): input
125
+ Returns:
126
+ tensor: output
127
+ """
128
+
129
+ out = self.activation(x)
130
+ out = self.conv1(out)
131
+ if self.bn == True:
132
+ out = self.bn1(out)
133
+
134
+ out = self.activation(out)
135
+ out = self.conv2(out)
136
+ if self.bn == True:
137
+ out = self.bn2(out)
138
+
139
+ if self.groups > 1:
140
+ out = self.conv_merge(out)
141
+
142
+ return self.skip_add.add(out, x)
143
+
144
+ class FeatureFusionBlock_custom(nn.Module):
145
+ """Feature fusion block."""
146
+
147
+ def __init__(
148
+ self,
149
+ features,
150
+ activation,
151
+ deconv=False,
152
+ bn=False,
153
+ expand=False,
154
+ align_corners=True,
155
+ width_ratio=1,
156
+ ):
157
+ """Init.
158
+ Args:
159
+ features (int): number of features
160
+ """
161
+ super(FeatureFusionBlock_custom, self).__init__()
162
+ self.width_ratio = width_ratio
163
+
164
+ self.deconv = deconv
165
+ self.align_corners = align_corners
166
+
167
+ self.groups = 1
168
+
169
+ self.expand = expand
170
+ out_features = features
171
+ if self.expand == True:
172
+ out_features = features // 2
173
+
174
+ self.out_conv = nn.Conv2d(
175
+ features,
176
+ out_features,
177
+ kernel_size=1,
178
+ stride=1,
179
+ padding=0,
180
+ bias=True,
181
+ groups=1,
182
+ )
183
+
184
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
185
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
186
+
187
+ self.skip_add = nn.quantized.FloatFunctional()
188
+
189
+ def forward(self, *xs):
190
+ """Forward pass.
191
+ Returns:
192
+ tensor: output
193
+ """
194
+ output = xs[0]
195
+
196
+ if len(xs) == 2:
197
+ res = self.resConfUnit1(xs[1])
198
+ if self.width_ratio != 1:
199
+ res = F.interpolate(res, size=(output.shape[2], output.shape[3]), mode='bilinear')
200
+
201
+ output = self.skip_add.add(output, res)
202
+ # output += res
203
+
204
+ output = self.resConfUnit2(output)
205
+
206
+ if self.width_ratio != 1:
207
+ # and output.shape[3] < self.width_ratio * output.shape[2]
208
+ #size=(image.shape[])
209
+ if (output.shape[3] / output.shape[2]) < (2 / 3) * self.width_ratio:
210
+ shape = 3 * output.shape[3]
211
+ else:
212
+ shape = int(self.width_ratio * 2 * output.shape[2])
213
+ output = F.interpolate(output, size=(2* output.shape[2], shape), mode='bilinear')
214
+ else:
215
+ output = nn.functional.interpolate(output, scale_factor=2,
216
+ mode="bilinear", align_corners=self.align_corners)
217
+ output = self.out_conv(output)
218
+ return output
219
+
220
+ def make_fusion_block(features, use_bn, width_ratio=1):
221
+ return FeatureFusionBlock_custom(
222
+ features,
223
+ nn.ReLU(False),
224
+ deconv=False,
225
+ bn=use_bn,
226
+ expand=False,
227
+ align_corners=True,
228
+ width_ratio=width_ratio,
229
+ )
230
+
231
+ class Interpolate(nn.Module):
232
+ """Interpolation module."""
233
+
234
+ def __init__(self, scale_factor, mode, align_corners=False):
235
+ """Init.
236
+ Args:
237
+ scale_factor (float): scaling
238
+ mode (str): interpolation mode
239
+ """
240
+ super(Interpolate, self).__init__()
241
+
242
+ self.interp = nn.functional.interpolate
243
+ self.scale_factor = scale_factor
244
+ self.mode = mode
245
+ self.align_corners = align_corners
246
+
247
+ def forward(self, x):
248
+ """Forward pass.
249
+ Args:
250
+ x (tensor): input
251
+ Returns:
252
+ tensor: interpolated data
253
+ """
254
+
255
+ x = self.interp(
256
+ x,
257
+ scale_factor=self.scale_factor,
258
+ mode=self.mode,
259
+ align_corners=self.align_corners,
260
+ )
261
+
262
+ return x
263
+
264
+ class DPTOutputAdapter(nn.Module):
265
+ """DPT output adapter.
266
+
267
+ :param num_channels: Number of output channels
268
+ :param stride_level: Stride level compared to the full-sized image.
269
+ E.g. 4 for 1/4th the size of the image.
270
+ :param patch_size: Int or tuple of the patch size over the full image size.
271
+ Patch size for smaller inputs will be computed accordingly.
272
+ :param hooks: Index of intermediate layers
273
+ :param layer_dims: Dimension of intermediate layers
274
+ :param feature_dim: Feature dimension
275
+ :param last_dim: out_channels/in_channels for the last two Conv2d when head_type == regression
276
+ :param use_bn: If set to True, activates batch norm
277
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
278
+ """
279
+
280
+ def __init__(self,
281
+ num_channels: int = 1,
282
+ stride_level: int = 1,
283
+ patch_size: Union[int, Tuple[int, int]] = 16,
284
+ main_tasks: Iterable[str] = ('rgb',),
285
+ hooks: List[int] = [2, 5, 8, 11],
286
+ layer_dims: List[int] = [96, 192, 384, 768],
287
+ feature_dim: int = 256,
288
+ last_dim: int = 32,
289
+ use_bn: bool = False,
290
+ dim_tokens_enc: Optional[int] = None,
291
+ head_type: str = 'regression',
292
+ output_width_ratio=1,
293
+ **kwargs):
294
+ super().__init__()
295
+ self.num_channels = num_channels
296
+ self.stride_level = stride_level
297
+ self.patch_size = pair(patch_size)
298
+ self.main_tasks = main_tasks
299
+ self.hooks = hooks
300
+ self.layer_dims = layer_dims
301
+ self.feature_dim = feature_dim
302
+ self.dim_tokens_enc = dim_tokens_enc * len(self.main_tasks) if dim_tokens_enc is not None else None
303
+ self.head_type = head_type
304
+
305
+ # Actual patch height and width, taking into account stride of input
306
+ self.P_H = max(1, self.patch_size[0] // stride_level)
307
+ self.P_W = max(1, self.patch_size[1] // stride_level)
308
+
309
+ self.scratch = make_scratch(layer_dims, feature_dim, groups=1, expand=False)
310
+
311
+ self.scratch.refinenet1 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
312
+ self.scratch.refinenet2 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
313
+ self.scratch.refinenet3 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
314
+ self.scratch.refinenet4 = make_fusion_block(feature_dim, use_bn, output_width_ratio)
315
+
316
+ if self.head_type == 'regression':
317
+ # The "DPTDepthModel" head
318
+ self.head = nn.Sequential(
319
+ nn.Conv2d(feature_dim, feature_dim // 2, kernel_size=3, stride=1, padding=1),
320
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
321
+ nn.Conv2d(feature_dim // 2, last_dim, kernel_size=3, stride=1, padding=1),
322
+ nn.ReLU(True),
323
+ nn.Conv2d(last_dim, self.num_channels, kernel_size=1, stride=1, padding=0)
324
+ )
325
+ elif self.head_type == 'semseg':
326
+ # The "DPTSegmentationModel" head
327
+ self.head = nn.Sequential(
328
+ nn.Conv2d(feature_dim, feature_dim, kernel_size=3, padding=1, bias=False),
329
+ nn.BatchNorm2d(feature_dim) if use_bn else nn.Identity(),
330
+ nn.ReLU(True),
331
+ nn.Dropout(0.1, False),
332
+ nn.Conv2d(feature_dim, self.num_channels, kernel_size=1),
333
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
334
+ )
335
+ else:
336
+ raise ValueError('DPT head_type must be "regression" or "semseg".')
337
+
338
+ if self.dim_tokens_enc is not None:
339
+ self.init(dim_tokens_enc=dim_tokens_enc)
340
+
341
+ def init(self, dim_tokens_enc=768):
342
+ """
343
+ Initialize parts of decoder that are dependent on dimension of encoder tokens.
344
+ Should be called when setting up MultiMAE.
345
+
346
+ :param dim_tokens_enc: Dimension of tokens coming from encoder
347
+ """
348
+ #print(dim_tokens_enc)
349
+
350
+ # Set up activation postprocessing layers
351
+ if isinstance(dim_tokens_enc, int):
352
+ dim_tokens_enc = 4 * [dim_tokens_enc]
353
+
354
+ self.dim_tokens_enc = [dt * len(self.main_tasks) for dt in dim_tokens_enc]
355
+
356
+ self.act_1_postprocess = nn.Sequential(
357
+ nn.Conv2d(
358
+ in_channels=self.dim_tokens_enc[0],
359
+ out_channels=self.layer_dims[0],
360
+ kernel_size=1, stride=1, padding=0,
361
+ ),
362
+ nn.ConvTranspose2d(
363
+ in_channels=self.layer_dims[0],
364
+ out_channels=self.layer_dims[0],
365
+ kernel_size=4, stride=4, padding=0,
366
+ bias=True, dilation=1, groups=1,
367
+ )
368
+ )
369
+
370
+ self.act_2_postprocess = nn.Sequential(
371
+ nn.Conv2d(
372
+ in_channels=self.dim_tokens_enc[1],
373
+ out_channels=self.layer_dims[1],
374
+ kernel_size=1, stride=1, padding=0,
375
+ ),
376
+ nn.ConvTranspose2d(
377
+ in_channels=self.layer_dims[1],
378
+ out_channels=self.layer_dims[1],
379
+ kernel_size=2, stride=2, padding=0,
380
+ bias=True, dilation=1, groups=1,
381
+ )
382
+ )
383
+
384
+ self.act_3_postprocess = nn.Sequential(
385
+ nn.Conv2d(
386
+ in_channels=self.dim_tokens_enc[2],
387
+ out_channels=self.layer_dims[2],
388
+ kernel_size=1, stride=1, padding=0,
389
+ )
390
+ )
391
+
392
+ self.act_4_postprocess = nn.Sequential(
393
+ nn.Conv2d(
394
+ in_channels=self.dim_tokens_enc[3],
395
+ out_channels=self.layer_dims[3],
396
+ kernel_size=1, stride=1, padding=0,
397
+ ),
398
+ nn.Conv2d(
399
+ in_channels=self.layer_dims[3],
400
+ out_channels=self.layer_dims[3],
401
+ kernel_size=3, stride=2, padding=1,
402
+ )
403
+ )
404
+
405
+ self.act_postprocess = nn.ModuleList([
406
+ self.act_1_postprocess,
407
+ self.act_2_postprocess,
408
+ self.act_3_postprocess,
409
+ self.act_4_postprocess
410
+ ])
411
+
412
+ def adapt_tokens(self, encoder_tokens):
413
+ # Adapt tokens
414
+ x = []
415
+ x.append(encoder_tokens[:, :])
416
+ x = torch.cat(x, dim=-1)
417
+ return x
418
+
419
+ def forward(self, encoder_tokens: List[torch.Tensor], image_size):
420
+ #input_info: Dict):
421
+ assert self.dim_tokens_enc is not None, 'Need to call init(dim_tokens_enc) function first'
422
+ H, W = image_size
423
+
424
+ # Number of patches in height and width
425
+ N_H = H // (self.stride_level * self.P_H)
426
+ N_W = W // (self.stride_level * self.P_W)
427
+
428
+ # Hook decoder onto 4 layers from specified ViT layers
429
+ layers = [encoder_tokens[hook] for hook in self.hooks]
430
+
431
+ # Extract only task-relevant tokens and ignore global tokens.
432
+ layers = [self.adapt_tokens(l) for l in layers]
433
+
434
+ # Reshape tokens to spatial representation
435
+ layers = [rearrange(l, 'b (nh nw) c -> b c nh nw', nh=N_H, nw=N_W) for l in layers]
436
+
437
+ layers = [self.act_postprocess[idx](l) for idx, l in enumerate(layers)]
438
+ # Project layers to chosen feature dim
439
+ layers = [self.scratch.layer_rn[idx](l) for idx, l in enumerate(layers)]
440
+
441
+ # Fuse layers using refinement stages
442
+ path_4 = self.scratch.refinenet4(layers[3])
443
+ path_3 = self.scratch.refinenet3(path_4, layers[2])
444
+ path_2 = self.scratch.refinenet2(path_3, layers[1])
445
+ path_1 = self.scratch.refinenet1(path_2, layers[0])
446
+
447
+ # Output head
448
+ out = self.head(path_1)
449
+
450
+ return out
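As a quick sanity check of the adapter above, here is a hedged sketch that pushes dummy ViT-Base-like tokens through `DPTOutputAdapter`; the import path and the hyperparameters are assumptions chosen for illustration:
```
import torch
from models.dpt_block import DPTOutputAdapter  # path assumed from this repository layout

B, C, H, W, P = 1, 768, 224, 224, 16
N = (H // P) * (W // P)                                     # 196 patch tokens
encoder_tokens = [torch.randn(B, N, C) for _ in range(12)]  # dummy outputs of 12 ViT blocks

dpt = DPTOutputAdapter(num_channels=1, patch_size=P, hooks=[2, 5, 8, 11],
                       layer_dims=[96, 192, 384, 768], feature_dim=256,
                       dim_tokens_enc=C, head_type='regression')
out = dpt(encoder_tokens, image_size=(H, W))
print(out.shape)   # torch.Size([1, 1, 224, 224]) with the regression head
```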
third_party/dust3r/croco/models/head_downstream.py ADDED
@@ -0,0 +1,58 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Heads for downstream tasks
6
+ # --------------------------------------------------------
7
+
8
+ """
9
+ A head is a module where the __init__ defines only the head hyperparameters.
10
+ A method setup(croconet) takes a CroCoNet and set all layers according to the head and croconet attributes.
11
+ The forward takes the features as well as a dictionary img_info containing the keys 'width' and 'height'
12
+ """
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ from .dpt_block import DPTOutputAdapter
17
+
18
+
19
+ class PixelwiseTaskWithDPT(nn.Module):
20
+ """ DPT module for CroCo.
21
+ by default, hooks_idx will be equal to:
22
+ * for encoder-only: 4 equally spread layers
23
+ * for encoder+decoder: last encoder + 3 equally spread layers of the decoder
24
+ """
25
+
26
+ def __init__(self, *, hooks_idx=None, layer_dims=[96,192,384,768],
27
+ output_width_ratio=1, num_channels=1, postprocess=None, **kwargs):
28
+ super(PixelwiseTaskWithDPT, self).__init__()
29
+ self.return_all_blocks = True # backbone needs to return all layers
30
+ self.postprocess = postprocess
31
+ self.output_width_ratio = output_width_ratio
32
+ self.num_channels = num_channels
33
+ self.hooks_idx = hooks_idx
34
+ self.layer_dims = layer_dims
35
+
36
+ def setup(self, croconet):
37
+ dpt_args = {'output_width_ratio': self.output_width_ratio, 'num_channels': self.num_channels}
38
+ if self.hooks_idx is None:
39
+ if hasattr(croconet, 'dec_blocks'): # encoder + decoder
40
+ step = {8: 3, 12: 4, 24: 8}[croconet.dec_depth]
41
+ hooks_idx = [croconet.dec_depth+croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
42
+ else: # encoder only
43
+ step = croconet.enc_depth//4
44
+ hooks_idx = [croconet.enc_depth-1-i*step for i in range(3,-1,-1)]
45
+ self.hooks_idx = hooks_idx
46
+ print(f' PixelwiseTaskWithDPT: automatically setting hook_idxs={self.hooks_idx}')
47
+ dpt_args['hooks'] = self.hooks_idx
48
+ dpt_args['layer_dims'] = self.layer_dims
49
+ self.dpt = DPTOutputAdapter(**dpt_args)
50
+ dim_tokens = [croconet.enc_embed_dim if hook<croconet.enc_depth else croconet.dec_embed_dim for hook in self.hooks_idx]
51
+ dpt_init_args = {'dim_tokens_enc': dim_tokens}
52
+ self.dpt.init(**dpt_init_args)
53
+
54
+
55
+ def forward(self, x, img_info):
56
+ out = self.dpt(x, image_size=(img_info['height'],img_info['width']))
57
+ if self.postprocess: out = self.postprocess(out)
58
+ return out
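The docstring above describes the `setup(croconet)` contract; below is a minimal sketch of the encoder-only case using a stand-in object that exposes only the attributes `setup` reads (`FakeCroCo` is hypothetical and purely for illustration):
```
import torch
from models.head_downstream import PixelwiseTaskWithDPT  # path assumed

class FakeCroCo:                 # stand-in for an encoder-only CroCoNet
    enc_depth = 12
    enc_embed_dim = 768

head = PixelwiseTaskWithDPT(num_channels=1)
head.setup(FakeCroCo())          # prints hook_idxs=[2, 5, 8, 11]

feats = [torch.randn(1, 196, 768) for _ in range(12)]     # all encoder blocks (return_all_blocks)
pred = head(feats, {'height': 224, 'width': 224})          # (1, 1, 224, 224)
```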
third_party/dust3r/croco/models/masking.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Masking utils
7
+ # --------------------------------------------------------
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ class RandomMask(nn.Module):
13
+ """
14
+ random masking
15
+ """
16
+
17
+ def __init__(self, num_patches, mask_ratio):
18
+ super().__init__()
19
+ self.num_patches = num_patches
20
+ self.num_mask = int(mask_ratio * self.num_patches)
21
+
22
+ def __call__(self, x):
23
+ noise = torch.rand(x.size(0), self.num_patches, device=x.device)
24
+ argsort = torch.argsort(noise, dim=1)
25
+ return argsort < self.num_mask
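A short illustration of `RandomMask` (shapes chosen arbitrarily): each row of the returned boolean mask marks exactly `int(mask_ratio * num_patches)` patches.
```
import torch
from models.masking import RandomMask  # path assumed

masker = RandomMask(num_patches=196, mask_ratio=0.75)
x = torch.randn(4, 196, 768)            # (B, num_patches, dim) token tensor
mask = masker(x)                        # bool tensor of shape (4, 196)
print(mask.shape, mask.sum(dim=1))      # each row sums to int(0.75 * 196) = 147
```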
third_party/dust3r/croco/models/pos_embed.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+
5
+ # --------------------------------------------------------
6
+ # Position embedding utils
7
+ # --------------------------------------------------------
8
+
9
+
10
+
11
+ import numpy as np
12
+
13
+ import torch
14
+
15
+ # --------------------------------------------------------
16
+ # 2D sine-cosine position embedding
17
+ # References:
18
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
19
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
20
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
21
+ # --------------------------------------------------------
22
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, n_cls_token=0):
23
+ """
24
+ grid_size: int of the grid height and width
25
+ return:
26
+ pos_embed: [grid_size*grid_size, embed_dim] or [n_cls_token+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
27
+ """
28
+ grid_h = np.arange(grid_size, dtype=np.float32)
29
+ grid_w = np.arange(grid_size, dtype=np.float32)
30
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
31
+ grid = np.stack(grid, axis=0)
32
+
33
+ grid = grid.reshape([2, 1, grid_size, grid_size])
34
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
35
+ if n_cls_token>0:
36
+ pos_embed = np.concatenate([np.zeros([n_cls_token, embed_dim]), pos_embed], axis=0)
37
+ return pos_embed
38
+
39
+
40
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
41
+ assert embed_dim % 2 == 0
42
+
43
+ # use half of dimensions to encode grid_h
44
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
45
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
46
+
47
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
48
+ return emb
49
+
50
+
51
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
52
+ """
53
+ embed_dim: output dimension for each position
54
+ pos: a list of positions to be encoded: size (M,)
55
+ out: (M, D)
56
+ """
57
+ assert embed_dim % 2 == 0
58
+ omega = np.arange(embed_dim // 2, dtype=float)
59
+ omega /= embed_dim / 2.
60
+ omega = 1. / 10000**omega # (D/2,)
61
+
62
+ pos = pos.reshape(-1) # (M,)
63
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
64
+
65
+ emb_sin = np.sin(out) # (M, D/2)
66
+ emb_cos = np.cos(out) # (M, D/2)
67
+
68
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
69
+ return emb
70
+
71
+
72
+ # --------------------------------------------------------
73
+ # Interpolate position embeddings for high-resolution
74
+ # References:
75
+ # MAE: https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
76
+ # DeiT: https://github.com/facebookresearch/deit
77
+ # --------------------------------------------------------
78
+ def interpolate_pos_embed(model, checkpoint_model):
79
+ if 'pos_embed' in checkpoint_model:
80
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
81
+ embedding_size = pos_embed_checkpoint.shape[-1]
82
+ num_patches = model.patch_embed.num_patches
83
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
84
+ # height (== width) for the checkpoint position embedding
85
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
86
+ # height (== width) for the new position embedding
87
+ new_size = int(num_patches ** 0.5)
88
+ # class_token and dist_token are kept unchanged
89
+ if orig_size != new_size:
90
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
91
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
92
+ # only the position tokens are interpolated
93
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
94
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
95
+ pos_tokens = torch.nn.functional.interpolate(
96
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
97
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
98
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
99
+ checkpoint_model['pos_embed'] = new_pos_embed
100
+
101
+
102
+ #----------------------------------------------------------
103
+ # RoPE2D: RoPE implementation in 2D
104
+ #----------------------------------------------------------
105
+
106
+ try:
107
+ from models.curope import cuRoPE2D
108
+ RoPE2D = cuRoPE2D
109
+ except ImportError:
110
+ print('Warning, cannot find cuda-compiled version of RoPE2D, using a slow pytorch version instead')
111
+
112
+ class RoPE2D(torch.nn.Module):
113
+
114
+ def __init__(self, freq=100.0, F0=1.0):
115
+ super().__init__()
116
+ self.base = freq
117
+ self.F0 = F0
118
+ self.cache = {}
119
+
120
+ def get_cos_sin(self, D, seq_len, device, dtype):
121
+ if (D,seq_len,device,dtype) not in self.cache:
122
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, D, 2).float().to(device) / D))
123
+ t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
124
+ freqs = torch.einsum("i,j->ij", t, inv_freq).to(dtype)
125
+ freqs = torch.cat((freqs, freqs), dim=-1)
126
+ cos = freqs.cos() # (Seq, Dim)
127
+ sin = freqs.sin()
128
+ self.cache[D,seq_len,device,dtype] = (cos,sin)
129
+ return self.cache[D,seq_len,device,dtype]
130
+
131
+ @staticmethod
132
+ def rotate_half(x):
133
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
134
+ return torch.cat((-x2, x1), dim=-1)
135
+
136
+ def apply_rope1d(self, tokens, pos1d, cos, sin):
137
+ assert pos1d.ndim==2
138
+ cos = torch.nn.functional.embedding(pos1d, cos)[:, None, :, :]
139
+ sin = torch.nn.functional.embedding(pos1d, sin)[:, None, :, :]
140
+ return (tokens * cos) + (self.rotate_half(tokens) * sin)
141
+
142
+ def forward(self, tokens, positions):
143
+ """
144
+ input:
145
+ * tokens: batch_size x nheads x ntokens x dim
146
+ * positions: batch_size x ntokens x 2 (y and x position of each token)
147
+ output:
148
+ * tokens after applying RoPE2D (batch_size x nheads x ntokens x dim)
149
+ """
150
+ assert tokens.size(3)%2==0, "number of dimensions should be a multiple of two"
151
+ D = tokens.size(3) // 2
152
+ assert positions.ndim==3 and positions.shape[-1] == 2 # Batch, Seq, 2
153
+ cos, sin = self.get_cos_sin(D, int(positions.max())+1, tokens.device, tokens.dtype)
154
+ # split features into two along the feature dimension, and apply rope1d on each half
155
+ y, x = tokens.chunk(2, dim=-1)
156
+ y = self.apply_rope1d(y, positions[:,:,0], cos, sin)
157
+ x = self.apply_rope1d(x, positions[:,:,1], cos, sin)
158
+ tokens = torch.cat((y, x), dim=-1)
159
+ return tokens
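Two small usage sketches for the utilities above, assuming the pure-PyTorch `RoPE2D` fallback is in use (i.e. the CUDA extension is not installed); shapes are illustrative:
```
import torch
from models.pos_embed import get_2d_sincos_pos_embed, RoPE2D  # paths assumed

# absolute sine-cosine embedding for a 14x14 patch grid, 768-d
pe = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)      # numpy array of shape (196, 768)

# relative 2D RoPE applied to attention queries
B, heads, N, dim = 2, 4, 196, 64
q = torch.randn(B, heads, N, dim)
ys, xs = torch.meshgrid(torch.arange(14), torch.arange(14), indexing='ij')
pos = torch.stack([ys, xs], dim=-1).reshape(1, N, 2).expand(B, N, 2)
q_rot = RoPE2D(freq=100.0)(q, pos)                             # same shape as q
```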
third_party/dust3r/croco/pretrain.py ADDED
@@ -0,0 +1,254 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # Pre-training CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
11
+ # --------------------------------------------------------
12
+ import argparse
13
+ import datetime
14
+ import json
15
+ import numpy as np
16
+ import os
17
+ import sys
18
+ import time
19
+ import math
20
+ from pathlib import Path
21
+ from typing import Iterable
22
+
23
+ import torch
24
+ import torch.distributed as dist
25
+ import torch.backends.cudnn as cudnn
26
+ from torch.utils.tensorboard import SummaryWriter
27
+ import torchvision.transforms as transforms
28
+ import torchvision.datasets as datasets
29
+
30
+ import utils.misc as misc
31
+ from utils.misc import NativeScalerWithGradNormCount as NativeScaler
32
+ from models.croco import CroCoNet
33
+ from models.criterion import MaskedMSE
34
+ from datasets.pairs_dataset import PairsDataset
35
+
36
+
37
+ def get_args_parser():
38
+ parser = argparse.ArgumentParser('CroCo pre-training', add_help=False)
39
+ # model and criterion
40
+ parser.add_argument('--model', default='CroCoNet()', type=str, help="string containing the model to build")
41
+ parser.add_argument('--norm_pix_loss', default=1, choices=[0,1], help="apply per-patch mean/std normalization before applying the loss")
42
+ # dataset
43
+ parser.add_argument('--dataset', default='habitat_release', type=str, help="training set")
44
+ parser.add_argument('--transforms', default='crop224+acolor', type=str, help="transforms to apply") # in the paper, we also use some homography and rotation, but found later that they were not useful or even harmful
45
+ # training
46
+ parser.add_argument('--seed', default=0, type=int, help="Random seed")
47
+ parser.add_argument('--batch_size', default=64, type=int, help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)")
48
+ parser.add_argument('--epochs', default=800, type=int, help="Maximum number of epochs for the scheduler")
49
+ parser.add_argument('--max_epoch', default=400, type=int, help="Stop training at this epoch")
50
+ parser.add_argument('--accum_iter', default=1, type=int, help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")
51
+ parser.add_argument('--weight_decay', type=float, default=0.05, help="weight decay (default: 0.05)")
52
+ parser.add_argument('--lr', type=float, default=None, metavar='LR', help='learning rate (absolute lr)')
53
+ parser.add_argument('--blr', type=float, default=1.5e-4, metavar='LR', help='base learning rate: absolute_lr = base_lr * total_batch_size / 256')
54
+ parser.add_argument('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
55
+ parser.add_argument('--warmup_epochs', type=int, default=40, metavar='N', help='epochs to warmup LR')
56
+ parser.add_argument('--amp', type=int, default=1, choices=[0,1], help="Use Automatic Mixed Precision for pretraining")
57
+ # others
58
+ parser.add_argument('--num_workers', default=8, type=int)
59
+ parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
60
+ parser.add_argument('--local_rank', default=-1, type=int)
61
+ parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
62
+ parser.add_argument('--save_freq', default=1, type=int, help='frequency (number of epochs) at which to save checkpoint in checkpoint-last.pth')
64
+ parser.add_argument('--keep_freq', default=20, type=int, help='frequency (number of epochs) at which to save checkpoint in checkpoint-%d.pth')
65
+ parser.add_argument('--print_freq', default=20, type=int, help='frequency (number of iterations) at which to print info while training')
65
+ # paths
66
+ parser.add_argument('--output_dir', default='./output/', type=str, help="path where to save the output")
67
+ parser.add_argument('--data_dir', default='./data/', type=str, help="path where data are stored")
68
+ return parser
69
+
70
+
71
+
72
+
73
+ def main(args):
74
+ misc.init_distributed_mode(args)
75
+ global_rank = misc.get_rank()
76
+ world_size = misc.get_world_size()
77
+
78
+ print("output_dir: "+args.output_dir)
79
+ if args.output_dir:
80
+ Path(args.output_dir).mkdir(parents=True, exist_ok=True)
81
+
82
+ # auto resume
83
+ last_ckpt_fname = os.path.join(args.output_dir, f'checkpoint-last.pth')
84
+ args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None
85
+
86
+ print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
87
+ print("{}".format(args).replace(', ', ',\n'))
88
+
89
+ device = "cuda" if torch.cuda.is_available() else "cpu"
90
+ device = torch.device(device)
91
+
92
+ # fix the seed
93
+ seed = args.seed + misc.get_rank()
94
+ torch.manual_seed(seed)
95
+ np.random.seed(seed)
96
+
97
+ cudnn.benchmark = True
98
+
99
+ ## training dataset and loader
100
+ print('Building dataset for {:s} with transforms {:s}'.format(args.dataset, args.transforms))
101
+ dataset = PairsDataset(args.dataset, trfs=args.transforms, data_dir=args.data_dir)
102
+ if world_size>1:
103
+ sampler_train = torch.utils.data.DistributedSampler(
104
+ dataset, num_replicas=world_size, rank=global_rank, shuffle=True
105
+ )
106
+ print("Sampler_train = %s" % str(sampler_train))
107
+ else:
108
+ sampler_train = torch.utils.data.RandomSampler(dataset)
109
+ data_loader_train = torch.utils.data.DataLoader(
110
+ dataset, sampler=sampler_train,
111
+ batch_size=args.batch_size,
112
+ num_workers=args.num_workers,
113
+ pin_memory=True,
114
+ drop_last=True,
115
+ )
116
+
117
+ ## model
118
+ print('Loading model: {:s}'.format(args.model))
119
+ model = eval(args.model)
120
+ print('Loading criterion: MaskedMSE(norm_pix_loss={:s})'.format(str(bool(args.norm_pix_loss))))
121
+ criterion = MaskedMSE(norm_pix_loss=bool(args.norm_pix_loss))
122
+
123
+ model.to(device)
124
+ model_without_ddp = model
125
+ print("Model = %s" % str(model_without_ddp))
126
+
127
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
128
+ if args.lr is None: # only base_lr is specified
129
+ args.lr = args.blr * eff_batch_size / 256
130
+ print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
131
+ print("actual lr: %.2e" % args.lr)
132
+ print("accumulate grad iterations: %d" % args.accum_iter)
133
+ print("effective batch size: %d" % eff_batch_size)
134
+
135
+ if args.distributed:
136
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True, static_graph=True)
137
+ model_without_ddp = model.module
138
+
139
+ param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay) # following timm: set wd as 0 for bias and norm layers
140
+ optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))
141
+ print(optimizer)
142
+ loss_scaler = NativeScaler()
143
+
144
+ misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
145
+
146
+ if global_rank == 0 and args.output_dir is not None:
147
+ log_writer = SummaryWriter(log_dir=args.output_dir)
148
+ else:
149
+ log_writer = None
150
+
151
+ print(f"Start training until {args.max_epoch} epochs")
152
+ start_time = time.time()
153
+ for epoch in range(args.start_epoch, args.max_epoch):
154
+ if world_size>1:
155
+ data_loader_train.sampler.set_epoch(epoch)
156
+
157
+ train_stats = train_one_epoch(
158
+ model, criterion, data_loader_train,
159
+ optimizer, device, epoch, loss_scaler,
160
+ log_writer=log_writer,
161
+ args=args
162
+ )
163
+
164
+ if args.output_dir and epoch % args.save_freq == 0 :
165
+ misc.save_model(
166
+ args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
167
+ loss_scaler=loss_scaler, epoch=epoch, fname='last')
168
+
169
+ if args.output_dir and (epoch % args.keep_freq == 0 or epoch + 1 == args.max_epoch) and (epoch>0 or args.max_epoch==1):
170
+ misc.save_model(
171
+ args=args, model_without_ddp=model_without_ddp, optimizer=optimizer,
172
+ loss_scaler=loss_scaler, epoch=epoch)
173
+
174
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
175
+ 'epoch': epoch,}
176
+
177
+ if args.output_dir and misc.is_main_process():
178
+ if log_writer is not None:
179
+ log_writer.flush()
180
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
181
+ f.write(json.dumps(log_stats) + "\n")
182
+
183
+ total_time = time.time() - start_time
184
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
185
+ print('Training time {}'.format(total_time_str))
186
+
187
+
188
+
189
+
190
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
191
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
192
+ device: torch.device, epoch: int, loss_scaler,
193
+ log_writer=None,
194
+ args=None):
195
+ model.train(True)
196
+ metric_logger = misc.MetricLogger(delimiter=" ")
197
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
198
+ header = 'Epoch: [{}]'.format(epoch)
199
+ accum_iter = args.accum_iter
200
+
201
+ optimizer.zero_grad()
202
+
203
+ if log_writer is not None:
204
+ print('log_dir: {}'.format(log_writer.log_dir))
205
+
206
+ for data_iter_step, (image1, image2) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
207
+
208
+ # we use a per iteration lr scheduler
209
+ if data_iter_step % accum_iter == 0:
210
+ misc.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
211
+
212
+ image1 = image1.to(device, non_blocking=True)
213
+ image2 = image2.to(device, non_blocking=True)
214
+ with torch.cuda.amp.autocast(enabled=bool(args.amp)):
215
+ out, mask, target = model(image1, image2)
216
+ loss = criterion(out, mask, target)
217
+
218
+ loss_value = loss.item()
219
+
220
+ if not math.isfinite(loss_value):
221
+ print("Loss is {}, stopping training".format(loss_value))
222
+ sys.exit(1)
223
+
224
+ loss /= accum_iter
225
+ loss_scaler(loss, optimizer, parameters=model.parameters(),
226
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
227
+ if (data_iter_step + 1) % accum_iter == 0:
228
+ optimizer.zero_grad()
229
+
230
+ torch.cuda.synchronize()
231
+
232
+ metric_logger.update(loss=loss_value)
233
+
234
+ lr = optimizer.param_groups[0]["lr"]
235
+ metric_logger.update(lr=lr)
236
+
237
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
238
+ if log_writer is not None and ((data_iter_step + 1) % (accum_iter*args.print_freq)) == 0:
239
+ # x-axis is based on epoch_1000x in the tensorboard, calibrating differences curves when batch size changes
240
+ epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
241
+ log_writer.add_scalar('train_loss', loss_value_reduce, epoch_1000x)
242
+ log_writer.add_scalar('lr', lr, epoch_1000x)
243
+
244
+ # gather the stats from all processes
245
+ metric_logger.synchronize_between_processes()
246
+ print("Averaged stats:", metric_logger)
247
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
248
+
249
+
250
+
251
+ if __name__ == '__main__':
252
+ args = get_args_parser()
253
+ args = args.parse_args()
254
+ main(args)
third_party/dust3r/croco/stereoflow/README.MD ADDED
@@ -0,0 +1,318 @@
1
+ ## CroCo-Stereo and CroCo-Flow
2
+
3
+ This README explains how to use CroCo-Stereo and CroCo-Flow as well as how they were trained.
4
+ All commands should be launched from the root directory.
5
+
6
+ ### Simple inference example
7
+
8
+ We provide a simple inference example for CroCo-Stereo and CroCo-Flow in the notebook `croco-stereo-flow-demo.ipynb`.
9
+ Before running it, please download the trained models with:
10
+ ```
11
+ bash stereoflow/download_model.sh crocostereo.pth
12
+ bash stereoflow/download_model.sh crocoflow.pth
13
+ ```
14
+
15
+ ### Prepare data for training or evaluation
16
+
17
+ Put the datasets used for training/evaluation in `./data/stereoflow` (or update the paths at the top of `stereoflow/datasets_stereo.py` and `stereoflow/datasets_flow.py`).
18
+ Please find below how the file structure should look for each dataset:
19
+ <details>
20
+ <summary>FlyingChairs</summary>
21
+
22
+ ```
23
+ ./data/stereoflow/FlyingChairs/
24
+ └───chairs_split.txt
25
+ └───data/
26
+ └─── ...
27
+ ```
28
+ </details>
29
+
30
+ <details>
31
+ <summary>MPI-Sintel</summary>
32
+
33
+ ```
34
+ ./data/stereoflow/MPI-Sintel/
35
+ └───training/
36
+ │ └───clean/
37
+ │ └───final/
38
+ │ └───flow/
39
+ └───test/
40
+ └───clean/
41
+ └───final/
42
+ ```
43
+ </details>
44
+
45
+ <details>
46
+ <summary>SceneFlow (including FlyingThings)</summary>
47
+
48
+ ```
49
+ ./data/stereoflow/SceneFlow/
50
+ └───Driving/
51
+ │ └───disparity/
52
+ │ └───frames_cleanpass/
53
+ │ └───frames_finalpass/
54
+ └───FlyingThings/
55
+ │ └───disparity/
56
+ │ └───frames_cleanpass/
57
+ │ └───frames_finalpass/
58
+ │ └───optical_flow/
59
+ └───Monkaa/
60
+ └───disparity/
61
+ └───frames_cleanpass/
62
+ └───frames_finalpass/
63
+ ```
64
+ </details>
65
+
66
+ <details>
67
+ <summary>TartanAir</summary>
68
+
69
+ ```
70
+ ./data/stereoflow/TartanAir/
71
+ └───abandonedfactory/
72
+ │ └───.../
73
+ └───abandonedfactory_night/
74
+ │ └───.../
75
+ └───.../
76
+ ```
77
+ </details>
78
+
79
+ <details>
80
+ <summary>Booster</summary>
81
+
82
+ ```
83
+ ./data/stereoflow/booster_gt/
84
+ └───train/
85
+ └───balanced/
86
+ └───Bathroom/
87
+ └───Bedroom/
88
+ └───...
89
+ ```
90
+ </details>
91
+
92
+ <details>
93
+ <summary>CREStereo</summary>
94
+
95
+ ```
96
+ ./data/stereoflow/crenet_stereo_trainset/
97
+ └───stereo_trainset/
98
+ └───crestereo/
99
+ └───hole/
100
+ └───reflective/
101
+ └───shapenet/
102
+ └───tree/
103
+ ```
104
+ </details>
105
+
106
+ <details>
107
+ <summary>ETH3D Two-view Low-res</summary>
108
+
109
+ ```
110
+ ./data/stereoflow/eth3d_lowres/
111
+ └───test/
112
+ │ └───lakeside_1l/
113
+ │ └───...
114
+ └───train/
115
+ │ └───delivery_area_1l/
116
+ │ └───...
117
+ └───train_gt/
118
+ └───delivery_area_1l/
119
+ └───...
120
+ ```
121
+ </details>
122
+
123
+ <details>
124
+ <summary>KITTI 2012</summary>
125
+
126
+ ```
127
+ ./data/stereoflow/kitti-stereo-2012/
128
+ └───testing/
129
+ │ └───colored_0/
130
+ │ └───colored_1/
131
+ └───training/
132
+ └───colored_0/
133
+ └───colored_1/
134
+ └───disp_occ/
135
+ └───flow_occ/
136
+ ```
137
+ </details>
138
+
139
+ <details>
140
+ <summary>KITTI 2015</summary>
141
+
142
+ ```
143
+ ./data/stereoflow/kitti-stereo-2015/
144
+ └───testing/
145
+ │ └───image_2/
146
+ │ └───image_3/
147
+ └───training/
148
+ └───image_2/
149
+ └───image_3/
150
+ └───disp_occ_0/
151
+ └───flow_occ/
152
+ ```
153
+ </details>
154
+
155
+ <details>
156
+ <summary>Middlebury</summary>
157
+
158
+ ```
159
+ ./data/stereoflow/middlebury
160
+ └───2005/
161
+ │ └───train/
162
+ │ └───Art/
163
+ │ └───...
164
+ └───2006/
165
+ │ └───Aloe/
166
+ │ └───Baby1/
167
+ │ └───...
168
+ └───2014/
169
+ │ └───Adirondack-imperfect/
170
+ │ └───Adirondack-perfect/
171
+ │ └───...
172
+ └───2021/
173
+ │ └───data/
174
+ │ └───artroom1/
175
+ │ └───artroom2/
176
+ │ └───...
177
+ └───MiddEval3_F/
178
+ └───test/
179
+ │ └───Australia/
180
+ │ └───...
181
+ └───train/
182
+ └───Adirondack/
183
+ └───...
184
+ ```
185
+ </details>
186
+
187
+ <details>
188
+ <summary>Spring</summary>
189
+
190
+ ```
191
+ ./data/stereoflow/spring/
192
+ └───test/
193
+ │ └───0003/
194
+ │ └───...
195
+ └───train/
196
+ └───0001/
197
+ └───...
198
+ ```
199
+ </details>
200
+
201
+
202
+ ### CroCo-Stereo
203
+
204
+ ##### Main model
205
+
206
+ The main training of CroCo-Stereo was performed on a series of datasets, and it was used as is for the Middlebury v3 benchmark.
207
+
208
+ ```
209
+ # Download the model
210
+ bash stereoflow/download_model.sh crocostereo.pth
211
+ # Middlebury v3 submission
212
+ python stereoflow/test.py --model stereoflow_models/crocostereo.pth --dataset "MdEval3('all_full')" --save submission --tile_overlap 0.9
213
+ # Training command that was used, using checkpoint-last.pth
214
+ python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/
215
+ # or it can be launched on multiple gpus (while maintaining the effective batch size), e.g. on 3 gpus:
216
+ torchrun --nproc_per_node 3 stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 2 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main/
217
+ ```
218
+
219
+ For evaluation on the validation sets, we also provide the model trained on the `subtrain` subset of the training sets.
220
+
221
+ ```
222
+ # Download the model
223
+ bash stereoflow/download_model.sh crocostereo_subtrain.pth
224
+ # Evaluation on validation sets
225
+ python stereoflow/test.py --model stereoflow_models/crocostereo_subtrain.pth --dataset "MdEval3('subval_full')+ETH3DLowRes('subval')+SceneFlow('test_finalpass')+SceneFlow('test_cleanpass')" --save metrics --tile_overlap 0.9
226
+ # Training command that was used (same as above but on subtrain, using checkpoint-best.pth), can also be launched on multiple gpus
227
+ python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('subtrain')+50*Md05('subtrain')+50*Md06('subtrain')+50*Md14('subtrain')+50*Md21('subtrain')+50*MdEval3('subtrain_full')+Booster('subtrain_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_subtrain/
228
+ ```
229
+
230
+ ##### Other models
231
+
232
+ <details>
233
+ <summary>Model for ETH3D</summary>
234
+ The model used for the submission on ETH3D is trained with the same command but using an unbounded Laplacian loss.
235
+
236
+ # Download the model
237
+ bash stereoflow/download_model.sh crocostereo_eth3d.pth
238
+ # ETH3D submission
239
+ python stereoflow/test.py --model stereoflow_models/crocostereo_eth3d.pth --dataset "ETH3DLowRes('all')" --save submission --tile_overlap 0.9
240
+ # Training command that was used
241
+ python -u stereoflow/train.py stereo --criterion "LaplacianLoss()" --tile_conf_mode conf_expbeta3 --dataset "CREStereo('train')+SceneFlow('train_allpass')+30*ETH3DLowRes('train')+50*Md05('train')+50*Md06('train')+50*Md14('train')+50*Md21('train')+50*MdEval3('train_full')+Booster('train_balanced')" --val_dataset "SceneFlow('test1of100_finalpass')+SceneFlow('test1of100_cleanpass')+ETH3DLowRes('subval')+Md05('subval')+Md06('subval')+Md14('subval')+Md21('subval')+MdEval3('subval_full')+Booster('subval_balanced')" --lr 3e-5 --batch_size 6 --epochs 32 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocostereo/main_eth3d/
242
+
243
+ </details>
244
+
245
+ <details>
246
+ <summary>Main model finetuned on Kitti</summary>
247
+
248
+ # Download the model
249
+ bash stereoflow/download_model.sh crocostereo_finetune_kitti.pth
250
+ # Kitti submission
251
+ python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.9
252
+ # Training that was used
253
+ python -u stereoflow/train.py stereo --crop 352 1216 --criterion "LaplacianLossBounded2()" --dataset "Kitti12('train')+Kitti15('train')" --lr 3e-5 --batch_size 1 --accum_iter 6 --epochs 20 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_kitti/ --save_every 5
254
+ </details>
255
+
256
+ <details>
257
+ <summary>Main model finetuned on Spring</summary>
258
+
259
+ # Download the model
260
+ bash stereoflow/download_model.sh crocostereo_finetune_spring.pth
261
+ # Spring submission
262
+ python stereoflow/test.py --model stereoflow_models/crocostereo_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9
263
+ # Training command that was used
264
+ python -u stereoflow/train.py stereo --criterion "LaplacianLossBounded2()" --dataset "Spring('train')" --lr 3e-5 --batch_size 6 --epochs 8 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocostereo.pth --output_dir xps/crocostereo/finetune_spring/
265
+ </details>
266
+
267
+ <details>
268
+ <summary>Smaller models</summary>
269
+ To train CroCo-Stereo with smaller CroCo pretrained models, simply replace the <code>--pretrained</code> argument. To download the smaller CroCo-Stereo models based on CroCo v2 pretraining with ViT-Base encoder and Small decoder, use <code>bash stereoflow/download_model.sh crocostereo_subtrain_vitb_smalldecoder.pth</code>, and for the model with a ViT-Base encoder and a Base decoder, use <code>bash stereoflow/download_model.sh crocostereo_subtrain_vitb_basedecoder.pth</code>.
270
+ </details>
271
+
272
+
273
+ ### CroCo-Flow
274
+
275
+ ##### Main model
276
+
277
+ The main training of CroCo-Flow was performed on the FlyingThings, FlyingChairs, MPI-Sintel and TartanAir datasets.
278
+ It was used for our submission to the MPI-Sintel benchmark.
279
+
280
+ ```
281
+ # Download the model
282
+ bash stereoflow/download_model.sh crocoflow.pth
283
+ # Evaluation
284
+ python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --save metrics --tile_overlap 0.9
285
+ # Sintel submission
286
+ python stereoflow/test.py --model stereoflow_models/crocoflow.pth --dataset "MPISintel('test_allpass')" --save submission --tile_overlap 0.9
287
+ # Training command that was used, with checkpoint-best.pth
288
+ python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "40*MPISintel('subtrain_cleanpass')+40*MPISintel('subtrain_finalpass')+4*FlyingThings('train_allpass')+4*FlyingChairs('train')+TartanAir('train')" --val_dataset "MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')" --lr 2e-5 --batch_size 8 --epochs 240 --img_per_epoch 30000 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --output_dir xps/crocoflow/main/
289
+ ```
290
+
291
+ ##### Other models
292
+
293
+ <details>
294
+ <summary>Main model finetuned on Kitti</summary>
295
+
296
+ # Download the model
297
+ bash stereoflow/download_model.sh crocoflow_finetune_kitti.pth
298
+ # Kitti submission
299
+ python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_kitti.pth --dataset "Kitti15('test')" --save submission --tile_overlap 0.99
300
+ # Training that was used, with checkpoint-last.pth
301
+ python -u stereoflow/train.py flow --crop 352 1216 --criterion "LaplacianLossBounded()" --dataset "Kitti15('train')+Kitti12('train')" --lr 2e-5 --batch_size 1 --accum_iter 8 --epochs 150 --save_every 5 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_kitti/
302
+ </details>
303
+
304
+ <details>
305
+ <summary>Main model finetuned on Spring</summary>
306
+
307
+ # Download the model
308
+ bash stereoflow/download_model.sh crocoflow_finetune_spring.pth
309
+ # Spring submission
310
+ python stereoflow/test.py --model stereoflow_models/crocoflow_finetune_spring.pth --dataset "Spring('test')" --save submission --tile_overlap 0.9
311
+ # Training command that was used, with checkpoint-last.pth
312
+ python -u stereoflow/train.py flow --criterion "LaplacianLossBounded()" --dataset "Spring('train')" --lr 2e-5 --batch_size 8 --epochs 12 --pretrained pretrained_models/CroCo_V2_ViTLarge_BaseDecoder.pth --start_from stereoflow_models/crocoflow.pth --output_dir xps/crocoflow/finetune_spring/
313
+ </details>
314
+
315
+ <details>
316
+ <summary>Smaller models</summary>
317
+ To train CroCo-Flow with smaller CroCo pretrained models, simply replace the <code>--pretrained</code> argument. To download the smaller CroCo-Flow models based on CroCo v2 pretraining with ViT-Base encoder and Small decoder, use <code>bash stereoflow/download_model.sh crocoflow_vitb_smalldecoder.pth</code>, and for the model with a ViT-Base encoder and a Base decoder, use <code>bash stereoflow/download_model.sh crocoflow_vitb_basedecoder.pth</code>.
318
+ </details>
third_party/dust3r/croco/stereoflow/augmentor.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Data augmentation for training stereo and flow
6
+ # --------------------------------------------------------
7
+
8
+ # References
9
+ # https://github.com/autonomousvision/unimatch/blob/master/dataloader/stereo/transforms.py
10
+ # https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/transforms.py
11
+
12
+
13
+ import numpy as np
14
+ import random
15
+ from PIL import Image
16
+
17
+ import cv2
18
+ cv2.setNumThreads(0)
19
+ cv2.ocl.setUseOpenCL(False)
20
+
21
+ import torch
22
+ from torchvision.transforms import ColorJitter
23
+ import torchvision.transforms.functional as FF
24
+
25
+ class StereoAugmentor(object):
26
+
27
+ def __init__(self, crop_size, scale_prob=0.5, scale_xonly=True, lhth=800., lminscale=0.0, lmaxscale=1.0, hminscale=-0.2, hmaxscale=0.4, scale_interp_nearest=True, rightjitterprob=0.5, v_flip_prob=0.5, color_aug_asym=True, color_choice_prob=0.5):
28
+ self.crop_size = crop_size
29
+ self.scale_prob = scale_prob
30
+ self.scale_xonly = scale_xonly
31
+ self.lhth = lhth
32
+ self.lminscale = lminscale
33
+ self.lmaxscale = lmaxscale
34
+ self.hminscale = hminscale
35
+ self.hmaxscale = hmaxscale
36
+ self.scale_interp_nearest = scale_interp_nearest
37
+ self.rightjitterprob = rightjitterprob
38
+ self.v_flip_prob = v_flip_prob
39
+ self.color_aug_asym = color_aug_asym
40
+ self.color_choice_prob = color_choice_prob
41
+
42
+ def _random_scale(self, img1, img2, disp):
43
+ ch,cw = self.crop_size
44
+ h,w = img1.shape[:2]
45
+ if self.scale_prob>0. and np.random.rand()<self.scale_prob:
46
+ min_scale, max_scale = (self.lminscale,self.lmaxscale) if min(h,w) < self.lhth else (self.hminscale,self.hmaxscale)
47
+ scale_x = 2. ** np.random.uniform(min_scale, max_scale)
48
+ scale_x = np.clip(scale_x, (cw+8) / float(w), None)
49
+ scale_y = 1.
50
+ if not self.scale_xonly:
51
+ scale_y = scale_x
52
+ scale_y = np.clip(scale_y, (ch+8) / float(h), None)
53
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
54
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
55
+ disp = cv2.resize(disp, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR if not self.scale_interp_nearest else cv2.INTER_NEAREST) * scale_x
56
+ else: # check if we need to resize to be able to crop
57
+ h,w = img1.shape[:2]
58
+ clip_scale = (cw+8) / float(w)
59
+ if clip_scale>1.:
60
+ scale_x = clip_scale
61
+ scale_y = scale_x if not self.scale_xonly else 1.0
62
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
63
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
64
+ disp = cv2.resize(disp, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR if not self.scale_interp_nearest else cv2.INTER_NEAREST) * scale_x
65
+ return img1, img2, disp
66
+
67
+ def _random_crop(self, img1, img2, disp):
68
+ h,w = img1.shape[:2]
69
+ ch,cw = self.crop_size
70
+ assert ch<=h and cw<=w, (img1.shape, h,w,ch,cw)
71
+ offset_x = np.random.randint(w - cw + 1)
72
+ offset_y = np.random.randint(h - ch + 1)
73
+ img1 = img1[offset_y:offset_y+ch,offset_x:offset_x+cw]
74
+ img2 = img2[offset_y:offset_y+ch,offset_x:offset_x+cw]
75
+ disp = disp[offset_y:offset_y+ch,offset_x:offset_x+cw]
76
+ return img1, img2, disp
77
+
78
+ def _random_vflip(self, img1, img2, disp):
79
+ # vertical flip
80
+ if self.v_flip_prob>0 and np.random.rand() < self.v_flip_prob:
81
+ img1 = np.copy(np.flipud(img1))
82
+ img2 = np.copy(np.flipud(img2))
83
+ disp = np.copy(np.flipud(disp))
84
+ return img1, img2, disp
85
+
86
+ def _random_rotate_shift_right(self, img2):
87
+ if self.rightjitterprob>0. and np.random.rand()<self.rightjitterprob:
88
+ angle, pixel = 0.1, 2
89
+ px = np.random.uniform(-pixel, pixel)
90
+ ag = np.random.uniform(-angle, angle)
91
+ image_center = (np.random.uniform(0, img2.shape[0]), np.random.uniform(0, img2.shape[1]) )
92
+ rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
93
+ img2 = cv2.warpAffine(img2, rot_mat, img2.shape[1::-1], flags=cv2.INTER_LINEAR)
94
+ trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
95
+ img2 = cv2.warpAffine(img2, trans_mat, img2.shape[1::-1], flags=cv2.INTER_LINEAR)
96
+ return img2
97
+
98
+ def _random_color_contrast(self, img1, img2):
99
+ if np.random.random() < 0.5:
100
+ contrast_factor = np.random.uniform(0.8, 1.2)
101
+ img1 = FF.adjust_contrast(img1, contrast_factor)
102
+ if self.color_aug_asym and np.random.random() < 0.5: contrast_factor = np.random.uniform(0.8, 1.2)
103
+ img2 = FF.adjust_contrast(img2, contrast_factor)
104
+ return img1, img2
105
+ def _random_color_gamma(self, img1, img2):
106
+ if np.random.random() < 0.5:
107
+ gamma = np.random.uniform(0.7, 1.5)
108
+ img1 = FF.adjust_gamma(img1, gamma)
109
+ if self.color_aug_asym and np.random.random() < 0.5: gamma = np.random.uniform(0.7, 1.5)
110
+ img2 = FF.adjust_gamma(img2, gamma)
111
+ return img1, img2
112
+ def _random_color_brightness(self, img1, img2):
113
+ if np.random.random() < 0.5:
114
+ brightness = np.random.uniform(0.5, 2.0)
115
+ img1 = FF.adjust_brightness(img1, brightness)
116
+ if self.color_aug_asym and np.random.random() < 0.5: brightness = np.random.uniform(0.5, 2.0)
117
+ img2 = FF.adjust_brightness(img2, brightness)
118
+ return img1, img2
119
+ def _random_color_hue(self, img1, img2):
120
+ if np.random.random() < 0.5:
121
+ hue = np.random.uniform(-0.1, 0.1)
122
+ img1 = FF.adjust_hue(img1, hue)
123
+ if self.color_aug_asym and np.random.random() < 0.5: hue = np.random.uniform(-0.1, 0.1)
124
+ img2 = FF.adjust_hue(img2, hue)
125
+ return img1, img2
126
+ def _random_color_saturation(self, img1, img2):
127
+ if np.random.random() < 0.5:
128
+ saturation = np.random.uniform(0.8, 1.2)
129
+ img1 = FF.adjust_saturation(img1, saturation)
130
+ if self.color_aug_asym and np.random.random() < 0.5: saturation = np.random.uniform(0.8, 1.2)
131
+ img2 = FF.adjust_saturation(img2, saturation)
132
+ return img1, img2
133
+ def _random_color(self, img1, img2):
134
+ trfs = [self._random_color_contrast,self._random_color_gamma,self._random_color_brightness,self._random_color_hue,self._random_color_saturation]
135
+ img1 = Image.fromarray(img1.astype('uint8'))
136
+ img2 = Image.fromarray(img2.astype('uint8'))
137
+ if np.random.random() < self.color_choice_prob:
138
+ # A single transform
139
+ t = random.choice(trfs)
140
+ img1, img2 = t(img1, img2)
141
+ else:
142
+ # Combination of trfs
143
+ # Random order
144
+ random.shuffle(trfs)
145
+ for t in trfs:
146
+ img1, img2 = t(img1, img2)
147
+ img1 = np.array(img1).astype(np.float32)
148
+ img2 = np.array(img2).astype(np.float32)
149
+ return img1, img2
150
+
151
+ def __call__(self, img1, img2, disp, dataset_name):
152
+ img1, img2, disp = self._random_scale(img1, img2, disp)
153
+ img1, img2, disp = self._random_crop(img1, img2, disp)
154
+ img1, img2, disp = self._random_vflip(img1, img2, disp)
155
+ img2 = self._random_rotate_shift_right(img2)
156
+ img1, img2 = self._random_color(img1, img2)
157
+ return img1, img2, disp
158
+
159
+
160
+
161
+ class FlowAugmentor:
162
+
163
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, spatial_aug_prob=0.8, stretch_prob=0.8, max_stretch=0.2, h_flip_prob=0.5, v_flip_prob=0.1, asymmetric_color_aug_prob=0.2):
164
+
165
+ # spatial augmentation params
166
+ self.crop_size = crop_size
167
+ self.min_scale = min_scale
168
+ self.max_scale = max_scale
169
+ self.spatial_aug_prob = spatial_aug_prob
170
+ self.stretch_prob = stretch_prob
171
+ self.max_stretch = max_stretch
172
+
173
+ # flip augmentation params
174
+ self.h_flip_prob = h_flip_prob
175
+ self.v_flip_prob = v_flip_prob
176
+
177
+ # photometric augmentation params
178
+ self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5 / 3.14)
179
+
180
+ self.asymmetric_color_aug_prob = asymmetric_color_aug_prob
181
+
182
+ def color_transform(self, img1, img2):
183
+ """ Photometric augmentation """
184
+
185
+ # asymmetric
186
+ if np.random.rand() < self.asymmetric_color_aug_prob:
187
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
188
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
189
+
190
+ # symmetric
191
+ else:
192
+ image_stack = np.concatenate([img1, img2], axis=0)
193
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
194
+ img1, img2 = np.split(image_stack, 2, axis=0)
195
+
196
+ return img1, img2
197
+
198
+ def _resize_flow(self, flow, scale_x, scale_y, factor=1.0):
199
+ if np.all(np.isfinite(flow)):
200
+ flow = cv2.resize(flow, None, fx=scale_x/factor, fy=scale_y/factor, interpolation=cv2.INTER_LINEAR)
201
+ flow = flow * [scale_x, scale_y]
202
+ else: # sparse version
203
+ fx, fy = scale_x, scale_y
204
+ ht, wd = flow.shape[:2]
205
+ coords = np.meshgrid(np.arange(wd), np.arange(ht))
206
+ coords = np.stack(coords, axis=-1)
207
+
208
+ coords = coords.reshape(-1, 2).astype(np.float32)
209
+ flow = flow.reshape(-1, 2).astype(np.float32)
210
+ valid = np.isfinite(flow[:,0])
211
+
212
+ coords0 = coords[valid]
213
+ flow0 = flow[valid]
214
+
215
+ ht1 = int(round(ht * fy/factor))
216
+ wd1 = int(round(wd * fx/factor))
217
+
218
+ rescale = np.expand_dims(np.array([fx, fy]), axis=0)
219
+ coords1 = coords0 * rescale / factor
220
+ flow1 = flow0 * rescale
221
+
222
+ xx = np.round(coords1[:, 0]).astype(np.int32)
223
+ yy = np.round(coords1[:, 1]).astype(np.int32)
224
+
225
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
226
+ xx = xx[v]
227
+ yy = yy[v]
228
+ flow1 = flow1[v]
229
+
230
+ flow = np.inf * np.ones([ht1, wd1, 2], dtype=np.float32) # invalid value every where, before we fill it with the correct ones
231
+ flow[yy, xx] = flow1
232
+ return flow
233
+
234
+ def spatial_transform(self, img1, img2, flow, dname):
235
+
236
+ if np.random.rand() < self.spatial_aug_prob:
237
+ # randomly sample scale
238
+ ht, wd = img1.shape[:2]
239
+ clip_min_scale = np.maximum(
240
+ (self.crop_size[0] + 8) / float(ht),
241
+ (self.crop_size[1] + 8) / float(wd))
242
+ min_scale, max_scale = self.min_scale, self.max_scale
243
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
244
+ scale_x = scale
245
+ scale_y = scale
246
+ if np.random.rand() < self.stretch_prob:
247
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
248
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
249
+ scale_x = np.clip(scale_x, clip_min_scale, None)
250
+ scale_y = np.clip(scale_y, clip_min_scale, None)
251
+ # rescale the images
252
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
253
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
254
+ flow = self._resize_flow(flow, scale_x, scale_y, factor=2.0 if dname=='Spring' else 1.0)
255
+ elif dname=="Spring":
256
+ flow = self._resize_flow(flow, 1.0, 1.0, factor=2.0)
257
+
258
+ if self.h_flip_prob>0. and np.random.rand() < self.h_flip_prob: # h-flip
259
+ img1 = img1[:, ::-1]
260
+ img2 = img2[:, ::-1]
261
+ flow = flow[:, ::-1] * [-1.0, 1.0]
262
+
263
+ if self.v_flip_prob>0. and np.random.rand() < self.v_flip_prob: # v-flip
264
+ img1 = img1[::-1, :]
265
+ img2 = img2[::-1, :]
266
+ flow = flow[::-1, :] * [1.0, -1.0]
267
+
268
+ # In case no cropping
269
+ if img1.shape[0] - self.crop_size[0] > 0:
270
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
271
+ else:
272
+ y0 = 0
273
+ if img1.shape[1] - self.crop_size[1] > 0:
274
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
275
+ else:
276
+ x0 = 0
277
+
278
+ img1 = img1[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
279
+ img2 = img2[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
280
+ flow = flow[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1]]
281
+
282
+ return img1, img2, flow
283
+
284
+ def __call__(self, img1, img2, flow, dname):
285
+ img1, img2, flow = self.spatial_transform(img1, img2, flow, dname)
286
+ img1, img2 = self.color_transform(img1, img2)
287
+ img1 = np.ascontiguousarray(img1)
288
+ img2 = np.ascontiguousarray(img2)
289
+ flow = np.ascontiguousarray(flow)
290
+ return img1, img2, flow
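For orientation while reading this diff: below is a minimal usage sketch of the `FlowAugmentor` defined above, run on synthetic inputs. The import path (`stereoflow.augmentor`, e.g. with the `croco/` directory on `sys.path`), the crop size and the dummy arrays are assumptions for illustration; real training pairs come from the datasets files further down.

```python
import numpy as np
from stereoflow.augmentor import FlowAugmentor  # assumed import path

rng = np.random.default_rng(0)
img1 = rng.integers(0, 256, size=(436, 1024, 3), dtype=np.uint8)  # stand-in for frame t
img2 = rng.integers(0, 256, size=(436, 1024, 3), dtype=np.uint8)  # stand-in for frame t+1
flow = rng.standard_normal((436, 1024, 2)).astype(np.float32)     # stand-in for dense flow

aug = FlowAugmentor(crop_size=(368, 496))  # arbitrary example crop size
# the last argument is the dataset name; 'Spring' triggers the half-resolution flow handling
img1_a, img2_a, flow_a = aug(img1, img2, flow, 'MPISintel')
print(img1_a.shape, flow_a.shape)  # both cropped to 368x496
```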
third_party/dust3r/croco/stereoflow/criterion.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Losses, metrics per batch, metrics per dataset
6
+ # --------------------------------------------------------
7
+
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+
12
+ def _get_gtnorm(gt):
13
+ if gt.size(1)==1: # stereo
14
+ return gt
15
+ # flow
16
+ return torch.sqrt(torch.sum(gt**2, dim=1, keepdims=True)) # Bx1xHxW
17
+
18
+ ############ losses without confidence
19
+
20
+ class L1Loss(nn.Module):
21
+
22
+ def __init__(self, max_gtnorm=None):
23
+ super().__init__()
24
+ self.max_gtnorm = max_gtnorm
25
+ self.with_conf = False
26
+
27
+ def _error(self, gt, predictions):
28
+ return torch.abs(gt-predictions)
29
+
30
+ def forward(self, predictions, gt, inspect=False):
31
+ mask = torch.isfinite(gt)
32
+ if self.max_gtnorm is not None:
33
+ mask *= _get_gtnorm(gt).expand(-1,gt.size(1),-1,-1)<self.max_gtnorm
34
+ if inspect:
35
+ return self._error(gt, predictions)
36
+ return self._error(gt[mask],predictions[mask]).mean()
37
+
38
+ ############## losses with confidence
39
+ ## there are several parametrizations
40
+
41
+ class LaplacianLoss(nn.Module): # used for CroCo-Stereo on ETH3D, d'=exp(d)
42
+
43
+ def __init__(self, max_gtnorm=None):
44
+ super().__init__()
45
+ self.max_gtnorm = max_gtnorm
46
+ self.with_conf = True
47
+
48
+ def forward(self, predictions, gt, conf):
49
+ mask = torch.isfinite(gt)
50
+ mask = mask[:,0,:,:]
51
+ if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:]<self.max_gtnorm
52
+ conf = conf.squeeze(1)
53
+ return ( torch.abs(gt-predictions).sum(dim=1)[mask] / torch.exp(conf[mask]) + conf[mask] ).mean()# + torch.log(2) => which is a constant
54
+
55
+
56
+ class LaplacianLossBounded(nn.Module): # used for CroCo-Flow ; in the equation of the paper, we have a=1/b
57
+ def __init__(self, max_gtnorm=10000., a=0.25, b=4.):
58
+ super().__init__()
59
+ self.max_gtnorm = max_gtnorm
60
+ self.with_conf = True
61
+ self.a, self.b = a, b
62
+
63
+ def forward(self, predictions, gt, conf):
64
+ mask = torch.isfinite(gt)
65
+ mask = mask[:,0,:,:]
66
+ if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:]<self.max_gtnorm
67
+ conf = conf.squeeze(1)
68
+ conf = (self.b - self.a) * torch.sigmoid(conf) + self.a
69
+ return ( torch.abs(gt-predictions).sum(dim=1)[mask] / conf[mask] + torch.log(conf)[mask] ).mean()# + torch.log(2) => which is a constant
70
+
71
+ class LaplacianLossBounded2(nn.Module): # used for CroCo-Stereo (except for ETH3D) ; in the equation of the paper, we have a=b
72
+ def __init__(self, max_gtnorm=None, a=3.0, b=3.0):
73
+ super().__init__()
74
+ self.max_gtnorm = max_gtnorm
75
+ self.with_conf = True
76
+ self.a, self.b = a, b
77
+
78
+ def forward(self, predictions, gt, conf):
79
+ mask = torch.isfinite(gt)
80
+ mask = mask[:,0,:,:]
81
+ if self.max_gtnorm is not None: mask *= _get_gtnorm(gt)[:,0,:,:]<self.max_gtnorm
82
+ conf = conf.squeeze(1)
83
+ conf = 2 * self.a * (torch.sigmoid(conf / self.b) - 0.5 )
84
+ return ( torch.abs(gt-predictions).sum(dim=1)[mask] / torch.exp(conf[mask]) + conf[mask] ).mean()# + torch.log(2) => which is a constant
85
+
86
+ ############## metrics per batch
87
+
88
+ class StereoMetrics(nn.Module):
89
+
90
+ def __init__(self, do_quantile=False):
91
+ super().__init__()
92
+ self.bad_ths = [0.5,1,2,3]
93
+ self.do_quantile = do_quantile
94
+
95
+ def forward(self, predictions, gt):
96
+ B = predictions.size(0)
97
+ metrics = {}
98
+ gtcopy = gt.clone()
99
+ mask = torch.isfinite(gtcopy)
100
+ gtcopy[~mask] = 999999.0 # we make a copy and put a non-infinite value, such that it does not become nan once multiplied by the mask value 0
101
+ Npx = mask.view(B,-1).sum(dim=1)
102
+ L1error = (torch.abs(gtcopy-predictions)*mask).view(B,-1)
103
+ L2error = (torch.square(gtcopy-predictions)*mask).view(B,-1)
104
+ # avgerr
105
+ metrics['avgerr'] = torch.mean(L1error.sum(dim=1)/Npx )
106
+ # rmse
107
+ metrics['rmse'] = torch.sqrt(L2error.sum(dim=1)/Npx).mean(dim=0)
108
+ # err > t for t in [0.5,1,2,3]
109
+ for ths in self.bad_ths:
110
+ metrics['bad@{:.1f}'.format(ths)] = (((L1error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100
111
+ return metrics
112
+
113
+ class FlowMetrics(nn.Module):
114
+ def __init__(self):
115
+ super().__init__()
116
+ self.bad_ths = [1,3,5]
117
+
118
+ def forward(self, predictions, gt):
119
+ B = predictions.size(0)
120
+ metrics = {}
121
+ mask = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite
122
+ Npx = mask.view(B,-1).sum(dim=1)
123
+ gtcopy = gt.clone() # to compute L1/L2 error, we need to have non-infinite value, the error computed at this locations will be ignored
124
+ gtcopy[:,0,:,:][~mask] = 999999.0
125
+ gtcopy[:,1,:,:][~mask] = 999999.0
126
+ L1error = (torch.abs(gtcopy-predictions).sum(dim=1)*mask).view(B,-1)
127
+ L2error = (torch.sqrt(torch.sum(torch.square(gtcopy-predictions),dim=1))*mask).view(B,-1)
128
+ metrics['L1err'] = torch.mean(L1error.sum(dim=1)/Npx )
129
+ metrics['EPE'] = torch.mean(L2error.sum(dim=1)/Npx )
130
+ for ths in self.bad_ths:
131
+ metrics['bad@{:.1f}'.format(ths)] = (((L2error>ths)* mask.view(B,-1)).sum(dim=1)/Npx).mean(dim=0) * 100
132
+ return metrics
133
+
134
+ ############## metrics per dataset
135
+ ## we update the average and maintain the number of pixels while adding data batch per batch
136
+ ## at the beginning, call reset()
137
+ ## after each batch, call add_batch(...)
138
+ ## at the end: call get_results()
139
+
140
+ class StereoDatasetMetrics(nn.Module):
141
+
142
+ def __init__(self):
143
+ super().__init__()
144
+ self.bad_ths = [0.5,1,2,3]
145
+
146
+ def reset(self):
147
+ self.agg_N = 0 # number of pixels so far
148
+ self.agg_L1err = torch.tensor(0.0) # L1 error so far
149
+ self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels
150
+ self._metrics = None
151
+
152
+ def add_batch(self, predictions, gt):
153
+ assert predictions.size(1)==1, predictions.size()
154
+ assert gt.size(1)==1, gt.size()
155
+ if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ...
156
+ L1err = torch.minimum( torch.minimum( torch.minimum(
157
+ torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1),
158
+ torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)),
159
+ torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)),
160
+ torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1))
161
+ valid = torch.isfinite(L1err)
162
+ else:
163
+ valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite
164
+ L1err = torch.sum(torch.abs(gt-predictions),dim=1)
165
+ N = valid.sum()
166
+ Nnew = self.agg_N + N
167
+ self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew
168
+ self.agg_N = Nnew
169
+ for i,th in enumerate(self.bad_ths):
170
+ self.agg_Nbad[i] += (L1err[valid]>th).sum().cpu()
171
+
172
+ def _compute_metrics(self):
173
+ if self._metrics is not None: return
174
+ out = {}
175
+ out['L1err'] = self.agg_L1err.item()
176
+ for i,th in enumerate(self.bad_ths):
177
+ out['bad@{:.1f}'.format(th)] = (float(self.agg_Nbad[i]) / self.agg_N).item() * 100.0
178
+ self._metrics = out
179
+
180
+ def get_results(self):
181
+ self._compute_metrics() # to avoid recompute them multiple times
182
+ return self._metrics
183
+
184
+ class FlowDatasetMetrics(nn.Module):
185
+
186
+ def __init__(self):
187
+ super().__init__()
188
+ self.bad_ths = [0.5,1,3,5]
189
+ self.speed_ths = [(0,10),(10,40),(40,torch.inf)]
190
+
191
+ def reset(self):
192
+ self.agg_N = 0 # number of pixels so far
193
+ self.agg_L1err = torch.tensor(0.0) # L1 error so far
194
+ self.agg_L2err = torch.tensor(0.0) # L2 (=EPE) error so far
195
+ self.agg_Nbad = [0 for _ in self.bad_ths] # counter of bad pixels
196
+ self.agg_EPEspeed = [torch.tensor(0.0) for _ in self.speed_ths] # EPE per speed bin so far
197
+ self.agg_Nspeed = [0 for _ in self.speed_ths] # N pixels per speed bin so far
198
+ self._metrics = None
199
+ self.pairname_results = {}
200
+
201
+ def add_batch(self, predictions, gt):
202
+ assert predictions.size(1)==2, predictions.size()
203
+ assert gt.size(1)==2, gt.size()
204
+ if gt.size(2)==predictions.size(2)*2 and gt.size(3)==predictions.size(3)*2: # special case for Spring ...
205
+ L1err = torch.minimum( torch.minimum( torch.minimum(
206
+ torch.sum(torch.abs(gt[:,:,0::2,0::2]-predictions),dim=1),
207
+ torch.sum(torch.abs(gt[:,:,1::2,0::2]-predictions),dim=1)),
208
+ torch.sum(torch.abs(gt[:,:,0::2,1::2]-predictions),dim=1)),
209
+ torch.sum(torch.abs(gt[:,:,1::2,1::2]-predictions),dim=1))
210
+ L2err = torch.minimum( torch.minimum( torch.minimum(
211
+ torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]-predictions),dim=1)),
212
+ torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]-predictions),dim=1))),
213
+ torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]-predictions),dim=1))),
214
+ torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]-predictions),dim=1)))
215
+ valid = torch.isfinite(L1err)
216
+ gtspeed = (torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,0::2,1::2]),dim=1)) +\
217
+ torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,0::2]),dim=1)) + torch.sqrt(torch.sum(torch.square(gt[:,:,1::2,1::2]),dim=1)) ) / 4.0 # let's just average them
218
+ else:
219
+ valid = torch.isfinite(gt[:,0,:,:]) # both x and y would be infinite
220
+ L1err = torch.sum(torch.abs(gt-predictions),dim=1)
221
+ L2err = torch.sqrt(torch.sum(torch.square(gt-predictions),dim=1))
222
+ gtspeed = torch.sqrt(torch.sum(torch.square(gt),dim=1))
223
+ N = valid.sum()
224
+ Nnew = self.agg_N + N
225
+ self.agg_L1err = float(self.agg_N)/Nnew * self.agg_L1err + L1err[valid].mean().cpu() * float(N)/Nnew
226
+ self.agg_L2err = float(self.agg_N)/Nnew * self.agg_L2err + L2err[valid].mean().cpu() * float(N)/Nnew
227
+ self.agg_N = Nnew
228
+ for i,th in enumerate(self.bad_ths):
229
+ self.agg_Nbad[i] += (L2err[valid]>th).sum().cpu()
230
+ for i,(th1,th2) in enumerate(self.speed_ths):
231
+ vv = (gtspeed[valid]>=th1) * (gtspeed[valid]<th2)
232
+ iNspeed = vv.sum()
233
+ if iNspeed==0: continue
234
+ iNnew = self.agg_Nspeed[i] + iNspeed
235
+ self.agg_EPEspeed[i] = float(self.agg_Nspeed[i]) / iNnew * self.agg_EPEspeed[i] + float(iNspeed) / iNnew * L2err[valid][vv].mean().cpu()
236
+ self.agg_Nspeed[i] = iNnew
237
+
238
+ def _compute_metrics(self):
239
+ if self._metrics is not None: return
240
+ out = {}
241
+ out['L1err'] = self.agg_L1err.item()
242
+ out['EPE'] = self.agg_L2err.item()
243
+ for i,th in enumerate(self.bad_ths):
244
+ out['bad@{:.1f}'.format(th)] = (float(self.agg_Nbad[i]) / self.agg_N).item() * 100.0
245
+ for i,(th1,th2) in enumerate(self.speed_ths):
246
+ out['s{:d}{:s}'.format(th1, '-'+str(th2) if th2<torch.inf else '+')] = self.agg_EPEspeed[i].item()
247
+ self._metrics = out
248
+
249
+ def get_results(self):
250
+ self._compute_metrics() # to avoid recompute them multiple times
251
+ return self._metrics
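A quick sanity-check sketch for the criteria and per-batch metrics above, on random tensors. `LaplacianLossBounded()` is the criterion named in the CroCo-Flow training command shown earlier in this diff; the import path is an assumption and the tensors are shape-only stand-ins.

```python
import torch
from stereoflow.criterion import LaplacianLossBounded, FlowMetrics  # assumed import path

B, H, W = 2, 64, 64
pred = torch.randn(B, 2, H, W)        # predicted flow (x and y components)
gt   = torch.randn(B, 2, H, W) * 5.0  # ground-truth flow, finite everywhere here
conf = torch.randn(B, 1, H, W)        # raw confidence, squashed into [a, b] inside the loss

criterion = LaplacianLossBounded()
loss = criterion(pred, gt, conf)      # scalar tensor
metrics = FlowMetrics()(pred, gt)     # dict with 'L1err', 'EPE', 'bad@1.0', ...
print(float(loss), float(metrics['EPE']))
```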
third_party/dust3r/croco/stereoflow/datasets_flow.py ADDED
@@ -0,0 +1,630 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Dataset structure for flow
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import os.path as osp
10
+ import pickle
11
+ import numpy as np
12
+ import struct
13
+ from PIL import Image
14
+ import json
15
+ import h5py
16
+ import torch
17
+ from torch.utils import data
18
+
19
+ from .augmentor import FlowAugmentor
20
+ from .datasets_stereo import _read_img, img_to_tensor, dataset_to_root, _read_pfm
21
+ from copy import deepcopy
22
+ dataset_to_root = deepcopy(dataset_to_root)
23
+
24
+ dataset_to_root.update(**{
25
+ 'TartanAir': './data/stereoflow/TartanAir',
26
+ 'FlyingChairs': './data/stereoflow/FlyingChairs/',
27
+ 'FlyingThings': osp.join(dataset_to_root['SceneFlow'],'FlyingThings')+'/',
28
+ 'MPISintel': './data/stereoflow//MPI-Sintel/'+'/',
29
+ })
30
+ cache_dir = "./data/stereoflow/datasets_flow_cache/"
31
+
32
+
33
+ def flow_to_tensor(disp):
34
+ return torch.from_numpy(disp).float().permute(2, 0, 1)
35
+
36
+ class FlowDataset(data.Dataset):
37
+
38
+ def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
39
+ self.split = split
40
+ if not augmentor: assert crop_size is None
41
+ if crop_size is not None: assert augmentor
42
+ self.crop_size = crop_size
43
+ self.augmentor_str = augmentor
44
+ self.augmentor = FlowAugmentor(crop_size) if augmentor else None
45
+ self.totensor = totensor
46
+ self.rmul = 1 # keep track of rmul
47
+ self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time)
48
+ self._prepare_data()
49
+ self._load_or_build_cache()
50
+
51
+ def prepare_data(self):
52
+ """
53
+ to be defined for each dataset
54
+ """
55
+ raise NotImplementedError
56
+
57
+ def __len__(self):
58
+ return len(self.pairnames) # each pairname is typically of the form (str, int1, int2)
59
+
60
+ def __getitem__(self, index):
61
+ pairname = self.pairnames[index]
62
+
63
+ # get filenames
64
+ img1name = self.pairname_to_img1name(pairname)
65
+ img2name = self.pairname_to_img2name(pairname)
66
+ flowname = self.pairname_to_flowname(pairname) if self.pairname_to_flowname is not None else None
67
+
68
+ # load images and disparities
69
+ img1 = _read_img(img1name)
70
+ img2 = _read_img(img2name)
71
+ flow = self.load_flow(flowname) if flowname is not None else None
72
+
73
+ # apply augmentations
74
+ if self.augmentor is not None:
75
+ img1, img2, flow = self.augmentor(img1, img2, flow, self.name)
76
+
77
+ if self.totensor:
78
+ img1 = img_to_tensor(img1)
79
+ img2 = img_to_tensor(img2)
80
+ if flow is not None:
81
+ flow = flow_to_tensor(flow)
82
+ else:
83
+ flow = torch.tensor([]) # to allow dataloader batching with default collate_fn
84
+ pairname = str(pairname) # transform potential tuple to str to be able to batch it
85
+
86
+ return img1, img2, flow, pairname
87
+
88
+ def __rmul__(self, v):
89
+ self.rmul *= v
90
+ self.pairnames = v * self.pairnames
91
+ return self
92
+
93
+ def __str__(self):
94
+ return f'{self.__class__.__name__}_{self.split}'
95
+
96
+ def __repr__(self):
97
+ s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})'
98
+ if self.rmul==1:
99
+ s+=f'\n\tnum pairs: {len(self.pairnames)}'
100
+ else:
101
+ s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})'
102
+ return s
103
+
104
+ def _set_root(self):
105
+ self.root = dataset_to_root[self.name]
106
+ assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}"
107
+
108
+ def _load_or_build_cache(self):
109
+ cache_file = osp.join(cache_dir, self.name+'.pkl')
110
+ if osp.isfile(cache_file):
111
+ with open(cache_file, 'rb') as fid:
112
+ self.pairnames = pickle.load(fid)[self.split]
113
+ else:
114
+ tosave = self._build_cache()
115
+ os.makedirs(cache_dir, exist_ok=True)
116
+ with open(cache_file, 'wb') as fid:
117
+ pickle.dump(tosave, fid)
118
+ self.pairnames = tosave[self.split]
119
+
120
+ class TartanAirDataset(FlowDataset):
121
+
122
+ def _prepare_data(self):
123
+ self.name = "TartanAir"
124
+ self._set_root()
125
+ assert self.split in ['train']
126
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[1]))
127
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'image_left/{:06d}_left.png'.format(pairname[2]))
128
+ self.pairname_to_flowname = lambda pairname: osp.join(self.root, pairname[0], 'flow/{:06d}_{:06d}_flow.npy'.format(pairname[1],pairname[2]))
129
+ self.pairname_to_str = lambda pairname: os.path.join(pairname[0][pairname[0].find('/')+1:], '{:06d}_{:06d}'.format(pairname[1], pairname[2]))
130
+ self.load_flow = _read_numpy_flow
131
+
132
+ def _build_cache(self):
133
+ seqs = sorted(os.listdir(self.root))
134
+ pairs = [(osp.join(s,s,difficulty,Pxxx),int(a[:6]),int(a[:6])+1) for s in seqs for difficulty in ['Easy','Hard'] for Pxxx in sorted(os.listdir(osp.join(self.root,s,s,difficulty))) for a in sorted(os.listdir(osp.join(self.root,s,s,difficulty,Pxxx,'image_left/')))[:-1]]
135
+ assert len(pairs)==306268, "incorrect parsing of pairs in TartanAir"
136
+ tosave = {'train': pairs}
137
+ return tosave
138
+
139
+ class FlyingChairsDataset(FlowDataset):
140
+
141
+ def _prepare_data(self):
142
+ self.name = "FlyingChairs"
143
+ self._set_root()
144
+ assert self.split in ['train','val']
145
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, 'data', pairname+'_img1.ppm')
146
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, 'data', pairname+'_img2.ppm')
147
+ self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'data', pairname+'_flow.flo')
148
+ self.pairname_to_str = lambda pairname: pairname
149
+ self.load_flow = _read_flo_file
150
+
151
+ def _build_cache(self):
152
+ split_file = osp.join(self.root, 'chairs_split.txt')
153
+ split_list = np.loadtxt(split_file, dtype=np.int32)
154
+ trainpairs = ['{:05d}'.format(i) for i in np.where(split_list==1)[0]+1]
155
+ valpairs = ['{:05d}'.format(i) for i in np.where(split_list==2)[0]+1]
156
+ assert len(trainpairs)==22232 and len(valpairs)==640, "incorrect parsing of pairs in FlyingChairs"
157
+ tosave = {'train': trainpairs, 'val': valpairs}
158
+ return tosave
159
+
160
+ class FlyingThingsDataset(FlowDataset):
161
+
162
+ def _prepare_data(self):
163
+ self.name = "FlyingThings"
164
+ self._set_root()
165
+ assert self.split in [f'{set_}_{pass_}pass{camstr}' for set_ in ['train','test','test1024'] for camstr in ['','_rightcam'] for pass_ in ['clean','final','all']]
166
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[1]))
167
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, f'frames_{pairname[3]}pass', pairname[0].replace('into_future','').replace('into_past',''), '{:04d}.png'.format(pairname[2]))
168
+ self.pairname_to_flowname = lambda pairname: osp.join(self.root, 'optical_flow', pairname[0], 'OpticalFlowInto{f:s}_{i:04d}_{c:s}.pfm'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' ))
169
+ self.pairname_to_str = lambda pairname: os.path.join(pairname[3]+'pass', pairname[0], 'Into{f:s}_{i:04d}_{c:s}'.format(f='Future' if 'future' in pairname[0] else 'Past', i=pairname[1], c='L' if 'left' in pairname[0] else 'R' ))
170
+ self.load_flow = _read_pfm_flow
171
+
172
+ def _build_cache(self):
173
+ tosave = {}
174
+ # train and test splits for the different passes
175
+ for set_ in ['train', 'test']:
176
+ sroot = osp.join(self.root, 'optical_flow', set_.upper())
177
+ fname_to_i = lambda f: int(f[len('OpticalFlowIntoFuture_'):-len('_L.pfm')])
178
+ pp = [(osp.join(set_.upper(), d, s, 'into_future/left'),fname_to_i(fname)) for d in sorted(os.listdir(sroot)) for s in sorted(os.listdir(osp.join(sroot,d))) for fname in sorted(os.listdir(osp.join(sroot,d, s, 'into_future/left')))[:-1]]
179
+ pairs = [(a,i,i+1) for a,i in pp]
180
+ pairs += [(a.replace('into_future','into_past'),i+1,i) for a,i in pp]
181
+ assert len(pairs)=={'train': 40302, 'test': 7866}[set_], "incorrect parsing of pairs in FlyingThings"
182
+ for cam in ['left','right']:
183
+ camstr = '' if cam=='left' else f'_{cam}cam'
184
+ for pass_ in ['final', 'clean']:
185
+ tosave[f'{set_}_{pass_}pass{camstr}'] = [(a.replace('left',cam),i,j,pass_) for a,i,j in pairs]
186
+ tosave[f'{set_}_allpass{camstr}'] = tosave[f'{set_}_cleanpass{camstr}'] + tosave[f'{set_}_finalpass{camstr}']
187
+ # test1024: this is the same split as unimatch 'validation' split
188
+ # see https://github.com/autonomousvision/unimatch/blob/master/dataloader/flow/datasets.py#L229
189
+ test1024_nsamples = 1024
190
+ alltest_nsamples = len(tosave['test_cleanpass']) # 7866
191
+ stride = alltest_nsamples // test1024_nsamples
192
+ remove = alltest_nsamples % test1024_nsamples
193
+ for cam in ['left','right']:
194
+ camstr = '' if cam=='left' else f'_{cam}cam'
195
+ for pass_ in ['final','clean']:
196
+ tosave[f'test1024_{pass_}pass{camstr}'] = sorted(tosave[f'test_{pass_}pass{camstr}'])[:-remove][::stride] # warning, it was not sorted before
197
+ assert len(tosave['test1024_cleanpass'])==1024, "incorrect parsing of pairs in Flying Things"
198
+ tosave[f'test1024_allpass{camstr}'] = tosave[f'test1024_cleanpass{camstr}'] + tosave[f'test1024_finalpass{camstr}']
199
+ return tosave
200
+
201
+
202
+ class MPISintelDataset(FlowDataset):
203
+
204
+ def _prepare_data(self):
205
+ self.name = "MPISintel"
206
+ self._set_root()
207
+ assert self.split in [s+'_'+p for s in ['train','test','subval','subtrain'] for p in ['cleanpass','finalpass','allpass']]
208
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1]))
209
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], 'frame_{:04d}.png'.format(pairname[1]+1))
210
+ self.pairname_to_flowname = lambda pairname: None if pairname[0].startswith('test/') else osp.join(self.root, pairname[0].replace('/clean/','/flow/').replace('/final/','/flow/'), 'frame_{:04d}.flo'.format(pairname[1]))
211
+ self.pairname_to_str = lambda pairname: osp.join(pairname[0], 'frame_{:04d}'.format(pairname[1]))
212
+ self.load_flow = _read_flo_file
213
+
214
+ def _build_cache(self):
215
+ trainseqs = sorted(os.listdir(self.root+'training/clean'))
216
+ trainpairs = [ (osp.join('training/clean', s),i) for s in trainseqs for i in range(1, len(os.listdir(self.root+'training/clean/'+s)))]
217
+ subvalseqs = ['temple_2','temple_3']
218
+ subtrainseqs = [s for s in trainseqs if s not in subvalseqs]
219
+ subvalpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subvalseqs)]
220
+ subtrainpairs = [ (p,i) for p,i in trainpairs if any(s in p for s in subtrainseqs)]
221
+ testseqs = sorted(os.listdir(self.root+'test/clean'))
222
+ testpairs = [ (osp.join('test/clean', s),i) for s in testseqs for i in range(1, len(os.listdir(self.root+'test/clean/'+s)))]
223
+ assert len(trainpairs)==1041 and len(testpairs)==552 and len(subvalpairs)==98 and len(subtrainpairs)==943, "incorrect parsing of pairs in MPI-Sintel"
224
+ tosave = {}
225
+ tosave['train_cleanpass'] = trainpairs
226
+ tosave['test_cleanpass'] = testpairs
227
+ tosave['subval_cleanpass'] = subvalpairs
228
+ tosave['subtrain_cleanpass'] = subtrainpairs
229
+ for t in ['train','test','subval','subtrain']:
230
+ tosave[t+'_finalpass'] = [(p.replace('/clean/','/final/'),i) for p,i in tosave[t+'_cleanpass']]
231
+ tosave[t+'_allpass'] = tosave[t+'_cleanpass'] + tosave[t+'_finalpass']
232
+ return tosave
233
+
234
+ def submission_save_pairname(self, pairname, prediction, outdir, _time):
235
+ assert prediction.shape[2]==2
236
+ outfile = os.path.join(outdir, 'submission', self.pairname_to_str(pairname)+'.flo')
237
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
238
+ writeFlowFile(prediction, outfile)
239
+
240
+ def finalize_submission(self, outdir):
241
+ assert self.split == 'test_allpass'
242
+ bundle_exe = "/nfs/data/ffs-3d/datasets/StereoFlow/MPI-Sintel/bundler/linux-x64/bundler" # eg <bundle_exe> <path_to_results_for_clean> <path_to_results_for_final> <output/bundled.lzma>
243
+ if os.path.isfile(bundle_exe):
244
+ cmd = f'{bundle_exe} "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"'
245
+ print(cmd)
246
+ os.system(cmd)
247
+ print(f'Done. Submission file at: "{outdir}/submission/bundled.lzma"')
248
+ else:
249
+ print('Could not find bundler executable for submission.')
250
+ print('Please download it and run:')
251
+ print(f'<bundle_exe> "{outdir}/submission/test/clean/" "{outdir}/submission/test/final" "{outdir}/submission/bundled.lzma"')
252
+
253
+ class SpringDataset(FlowDataset):
254
+
255
+ def _prepare_data(self):
256
+ self.name = "Spring"
257
+ self._set_root()
258
+ assert self.split in ['train','test','subtrain','subval']
259
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4]))
260
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname[0], pairname[1], 'frame_'+pairname[3], 'frame_{:s}_{:04d}.png'.format(pairname[3], pairname[4]+(1 if pairname[2]=='FW' else -1)))
261
+ self.pairname_to_flowname = lambda pairname: None if pairname[0]=='test' else osp.join(self.root, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5')
262
+ self.pairname_to_str = lambda pairname: osp.join(pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}')
263
+ self.load_flow = _read_hdf5_flow
264
+
265
+ def _build_cache(self):
266
+ # train
267
+ trainseqs = sorted(os.listdir( osp.join(self.root,'train')))
268
+ trainpairs = []
269
+ for leftright in ['left','right']:
270
+ for fwbw in ['FW','BW']:
271
+ trainpairs += [('train',s,fwbw,leftright,int(f[len(f'flow_{fwbw}_{leftright}_'):-len('.flo5')])) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,f'flow_{fwbw}_{leftright}')))]
272
+ # test
273
+ testseqs = sorted(os.listdir( osp.join(self.root,'test')))
274
+ testpairs = []
275
+ for leftright in ['left','right']:
276
+ testpairs += [('test',s,'FW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]]
277
+ testpairs += [('test',s,'BW',leftright,int(f[len(f'frame_{leftright}_'):-len('.png')])+1) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,f'frame_{leftright}')))[:-1]]
278
+ # subtrain / subval
279
+ subtrainpairs = [p for p in trainpairs if p[1]!='0041']
280
+ subvalpairs = [p for p in trainpairs if p[1]=='0041']
281
+ assert len(trainpairs)==19852 and len(testpairs)==3960 and len(subtrainpairs)==19472 and len(subvalpairs)==380, "incorrect parsing of pairs in Spring"
282
+ tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
283
+ return tosave
284
+
285
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
286
+ assert prediction.ndim==3
287
+ assert prediction.shape[2]==2
288
+ assert prediction.dtype==np.float32
289
+ outfile = osp.join(outdir, pairname[0], pairname[1], f'flow_{pairname[2]}_{pairname[3]}', f'flow_{pairname[2]}_{pairname[3]}_{pairname[4]:04d}.flo5')
290
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
291
+ writeFlo5File(prediction, outfile)
292
+
293
+ def finalize_submission(self, outdir):
294
+ assert self.split=='test'
295
+ exe = f"{self.root}/flow_subsampling"
296
+ if os.path.isfile(exe):
297
+ cmd = f'cd "{outdir}/test"; {exe} .'
298
+ print(cmd)
299
+ os.system(cmd)
300
+ print(f'Done. Submission file at {outdir}/test/flow_submission.hdf5')
301
+ else:
302
+ print('Could not find flow_subsampling executable for submission.')
303
+ print('Please download it and run:')
304
+ print(f'cd "{outdir}/test"; <flow_subsampling_exe> .')
305
+
306
+
307
+ class Kitti12Dataset(FlowDataset):
308
+
309
+ def _prepare_data(self):
310
+ self.name = "Kitti12"
311
+ self._set_root()
312
+ assert self.split in ['train','test']
313
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png')
314
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png')
315
+ self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/flow_occ/')+'_10.png')
316
+ self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/')
317
+ self.load_flow = _read_kitti_flow
318
+
319
+ def _build_cache(self):
320
+ trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)]
321
+ testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)]
322
+ assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12"
323
+ tosave = {'train': trainseqs, 'test': testseqs}
324
+ return tosave
325
+
326
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
327
+ assert prediction.ndim==3
328
+ assert prediction.shape[2]==2
329
+ outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png')
330
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
331
+ writeFlowKitti(outfile, prediction)
332
+
333
+ def finalize_submission(self, outdir):
334
+ assert self.split=='test'
335
+ cmd = f'cd {outdir}/; zip -r "kitti12_flow_results.zip" .'
336
+ print(cmd)
337
+ os.system(cmd)
338
+ print(f'Done. Submission file at {outdir}/kitti12_flow_results.zip')
339
+
340
+
341
+ class Kitti15Dataset(FlowDataset):
342
+
343
+ def _prepare_data(self):
344
+ self.name = "Kitti15"
345
+ self._set_root()
346
+ assert self.split in ['train','subtrain','subval','test']
347
+ self.pairname_to_img1name = lambda pairname: osp.join(self.root, pairname+'_10.png')
348
+ self.pairname_to_img2name = lambda pairname: osp.join(self.root, pairname+'_11.png')
349
+ self.pairname_to_flowname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/flow_occ/')+'_10.png')
350
+ self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/')
351
+ self.load_flow = _read_kitti_flow
352
+
353
+ def _build_cache(self):
354
+ trainseqs = ["training/image_2/%06d"%(i) for i in range(200)]
355
+ subtrainseqs = trainseqs[:-10]
356
+ subvalseqs = trainseqs[-10:]
357
+ testseqs = ["testing/image_2/%06d"%(i) for i in range(200)]
358
+ assert len(trainseqs)==200 and len(subtrainseqs)==190 and len(subvalseqs)==10 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15"
359
+ tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs}
360
+ return tosave
361
+
362
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
363
+ assert prediction.ndim==3
364
+ assert prediction.shape[2]==2
365
+ outfile = os.path.join(outdir, 'flow', pairname.split('/')[-1]+'_10.png')
366
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
367
+ writeFlowKitti(outfile, prediction)
368
+
369
+ def finalize_submission(self, outdir):
370
+ assert self.split=='test'
371
+ cmd = f'cd {outdir}/; zip -r "kitti15_flow_results.zip" flow'
372
+ print(cmd)
373
+ os.system(cmd)
374
+ print(f'Done. Submission file at {outdir}/kitti15_flow_results.zip')
375
+
376
+
377
+ import cv2
378
+ def _read_numpy_flow(filename):
379
+ return np.load(filename)
380
+
381
+ def _read_pfm_flow(filename):
382
+ f, _ = _read_pfm(filename)
383
+ assert np.all(f[:,:,2]==0.0)
384
+ return np.ascontiguousarray(f[:,:,:2])
385
+
386
+ TAG_FLOAT = 202021.25 # tag to check the sanity of the file
387
+ TAG_STRING = 'PIEH' # string containing the tag
388
+ MIN_WIDTH = 1
389
+ MAX_WIDTH = 99999
390
+ MIN_HEIGHT = 1
391
+ MAX_HEIGHT = 99999
392
+ def readFlowFile(filename):
393
+ """
394
+ readFlowFile(<FILENAME>) reads a flow file <FILENAME> into a 2-band np.array.
395
+ if <FILENAME> does not exist, an IOError is raised.
396
+ if <FILENAME> does not end with '.flo', or if the tag, the width, the height or the file's size is illegal, an Exception is raised.
397
+ ---- PARAMETERS ----
398
+ filename: string containing the name of the file to read a flow
399
+ ---- OUTPUTS ----
400
+ a np.array of dimension (height x width x 2) containing the flow of type 'float32'
401
+ """
402
+
403
+ # check filename
404
+ if not filename.endswith(".flo"):
405
+ raise Exception("readFlowFile({:s}): filename must finish with '.flo'".format(filename))
406
+
407
+ # open the file and read it
408
+ with open(filename,'rb') as f:
409
+ # check tag
410
+ tag = struct.unpack('f',f.read(4))[0]
411
+ if tag != TAG_FLOAT:
412
+ raise Exception("flow_utils.readFlowFile({:s}): wrong tag".format(filename))
413
+ # read dimension
414
+ w,h = struct.unpack('ii',f.read(8))
415
+ if w < MIN_WIDTH or w > MAX_WIDTH:
416
+ raise Exception("flow_utils.readFlowFile({:s}: illegal width {:d}".format(filename,w))
417
+ if h < MIN_HEIGHT or h > MAX_HEIGHT:
418
+ raise Exception("flow_utils.readFlowFile({:s}: illegal height {:d}".format(filename,h))
419
+ flow = np.fromfile(f,'float32')
420
+ if not flow.shape == (h*w*2,):
421
+ raise Exception("flow_utils.readFlowFile({:s}: illegal size of the file".format(filename))
422
+ flow.shape = (h,w,2)
423
+ return flow
424
+
425
+ def writeFlowFile(flow,filename):
426
+ """
427
+ writeFlowFile(flow,<FILENAME>) writes the flow to the file <FILENAME>.
428
+ if <FILENAME> does not exist, an IOError is raised.
429
+ if <FILENAME> does not end with '.flo' or the flow does not have 2 bands, an Exception is raised.
430
+ ---- PARAMETERS ----
431
+ flow: np.array of dimension (height x width x 2) containing the flow to write
432
+ filename: string containing the name of the file to write a flow
433
+ """
434
+
435
+ # check filename
436
+ if not filename.endswith(".flo"):
437
+ raise Exception("flow_utils.writeFlowFile(<flow>,{:s}): filename must finish with '.flo'".format(filename))
438
+
439
+ if not flow.shape[2:] == (2,):
440
+ raise Exception("flow_utils.writeFlowFile(<flow>,{:s}): <flow> must have 2 bands".format(filename))
441
+
442
+
443
+ # open the file and write it
444
+ with open(filename,'wb') as f:
445
+ # write TAG
446
+ f.write( TAG_STRING.encode('utf-8') )
447
+ # write dimension
448
+ f.write( struct.pack('ii',flow.shape[1],flow.shape[0]) )
449
+ # write the flow
450
+
451
+ flow.astype(np.float32).tofile(f)
452
+
453
+ _read_flo_file = readFlowFile
454
+
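As a small aside, the Middlebury `.flo` helpers above round-trip cleanly; the sketch below writes and re-reads a tiny random flow through a temporary file. The import path is again an assumption (the module also pulls in `h5py`, `cv2` and `torch` at import time).

```python
import os, tempfile
import numpy as np
from stereoflow.datasets_flow import readFlowFile, writeFlowFile  # assumed import path

flow = np.random.randn(4, 5, 2).astype(np.float32)   # tiny 4x5 flow field
path = os.path.join(tempfile.mkdtemp(), 'tiny.flo')
writeFlowFile(flow, path)                             # writes the 'PIEH' tag, width/height, then the data
assert np.allclose(readFlowFile(path), flow)          # same values back, shape (4, 5, 2)
```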
455
+ def _read_kitti_flow(filename):
456
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR)
457
+ flow = flow[:, :, ::-1].astype(np.float32)
458
+ valid = flow[:, :, 2]>0
459
+ flow = flow[:, :, :2]
460
+ flow = (flow - 2 ** 15) / 64.0
461
+ flow[~valid,0] = np.inf
462
+ flow[~valid,1] = np.inf
463
+ return flow
464
+ _read_hd1k_flow = _read_kitti_flow
465
+
466
+
467
+ def writeFlowKitti(filename, uv):
468
+ uv = 64.0 * uv + 2 ** 15
469
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
470
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
471
+ cv2.imwrite(filename, uv[..., ::-1])
472
+
473
+ def writeFlo5File(flow, filename):
474
+ with h5py.File(filename, "w") as f:
475
+ f.create_dataset("flow", data=flow, compression="gzip", compression_opts=5)
476
+
477
+ def _read_hdf5_flow(filename):
478
+ flow = np.asarray(h5py.File(filename)['flow'])
479
+ flow[np.isnan(flow)] = np.inf # make invalid values as +inf
480
+ return flow.astype(np.float32)
481
+
482
+ # flow visualization
483
+ RY = 15
484
+ YG = 6
485
+ GC = 4
486
+ CB = 11
487
+ BM = 13
488
+ MR = 6
489
+ UNKNOWN_THRESH = 1e9
490
+
491
+ def colorTest():
492
+ """
493
+ flow_utils.colorTest(): display an example image showing the color encoding scheme
494
+ """
495
+ import matplotlib.pylab as plt
496
+ truerange = 1
497
+ h,w = 151,151
498
+ trange = truerange*1.04
499
+ s2 = round(h/2)
500
+ x,y = np.meshgrid(range(w),range(h))
501
+ u = x*trange/s2-trange
502
+ v = y*trange/s2-trange
503
+ img = _computeColor(np.concatenate((u[:,:,np.newaxis],v[:,:,np.newaxis]),2)/trange/np.sqrt(2))
504
+ plt.imshow(img)
505
+ plt.axis('off')
506
+ plt.axhline(round(h/2),color='k')
507
+ plt.axvline(round(w/2),color='k')
508
+
509
+ def flowToColor(flow, maxflow=None, maxmaxflow=None, saturate=False):
510
+ """
511
+ flow_utils.flowToColor(flow): return a color code flow field, normalized based on the maximum l2-norm of the flow
512
+ flow_utils.flowToColor(flow,maxflow): return a color code flow field, normalized by maxflow
513
+ ---- PARAMETERS ----
514
+ flow: flow to display of shape (height x width x 2)
515
+ maxflow (default:None): if given, normalize the flow by its value, otherwise by the flow norm
516
+ maxmaxflow (default:None): if given, upper-bound the normalization value (the minimum of maxmaxflow and the flow norm is used)
517
+ ---- OUTPUT ----
518
+ an np.array of shape (height x width x 3) of type uint8 containing a color code of the flow
519
+ """
520
+ h,w,n = flow.shape
521
+ # check size of flow
522
+ assert n == 2, "flow_utils.flowToColor(flow): flow must have 2 bands"
523
+ # fix unknown flow
524
+ unknown_idx = np.max(np.abs(flow),2)>UNKNOWN_THRESH
525
+ flow[unknown_idx] = 0.0
526
+ # compute max flow if needed
527
+ if maxflow is None:
528
+ maxflow = flowMaxNorm(flow)
529
+ if maxmaxflow is not None:
530
+ maxflow = min(maxmaxflow, maxflow)
531
+ # normalize flow
532
+ eps = np.spacing(1) # minimum positive float value to avoid division by 0
533
+ # compute the flow
534
+ img = _computeColor(flow/(maxflow+eps), saturate=saturate)
535
+ # put black pixels in unknown location
536
+ img[ np.tile( unknown_idx[:,:,np.newaxis],[1,1,3]) ] = 0.0
537
+ return img
538
+
539
+ def flowMaxNorm(flow):
540
+ """
541
+ flow_utils.flowMaxNorm(flow): return the maximum of the l2-norm of the given flow
542
+ ---- PARAMETERS ----
543
+ flow: the flow
544
+
545
+ ---- OUTPUT ----
546
+ a float containing the maximum of the l2-norm of the flow
547
+ """
548
+ return np.max( np.sqrt( np.sum( np.square( flow ) , 2) ) )
549
+
550
+ def _computeColor(flow, saturate=True):
551
+ """
552
+ flow_utils._computeColor(flow): compute color codes for the flow field flow
553
+
554
+ ---- PARAMETERS ----
555
+ flow: np.array of dimension (height x width x 2) containing the flow to display
556
+ ---- OUTPUTS ----
557
+ an np.array of dimension (height x width x 3) containing the color conversion of the flow
558
+ """
559
+ # set nan to 0
560
+ nanidx = np.isnan(flow[:,:,0])
561
+ flow[nanidx] = 0.0
562
+
563
+ # colorwheel
564
+ ncols = RY + YG + GC + CB + BM + MR
565
+ nchans = 3
566
+ colorwheel = np.zeros((ncols,nchans),'uint8')
567
+ col = 0;
568
+ #RY
569
+ colorwheel[:RY,0] = 255
570
+ colorwheel[:RY,1] = [(255*i) // RY for i in range(RY)]
571
+ col += RY
572
+ # YG
573
+ colorwheel[col:col+YG,0] = [255 - (255*i) // YG for i in range(YG)]
574
+ colorwheel[col:col+YG,1] = 255
575
+ col += YG
576
+ # GC
577
+ colorwheel[col:col+GC,1] = 255
578
+ colorwheel[col:col+GC,2] = [(255*i) // GC for i in range(GC)]
579
+ col += GC
580
+ # CB
581
+ colorwheel[col:col+CB,1] = [255 - (255*i) // CB for i in range(CB)]
582
+ colorwheel[col:col+CB,2] = 255
583
+ col += CB
584
+ # BM
585
+ colorwheel[col:col+BM,0] = [(255*i) // BM for i in range(BM)]
586
+ colorwheel[col:col+BM,2] = 255
587
+ col += BM
588
+ # MR
589
+ colorwheel[col:col+MR,0] = 255
590
+ colorwheel[col:col+MR,2] = [255 - (255*i) // MR for i in range(MR)]
591
+
592
+ # compute utility variables
593
+ rad = np.sqrt( np.sum( np.square(flow) , 2) ) # magnitude
594
+ a = np.arctan2( -flow[:,:,1] , -flow[:,:,0]) / np.pi # angle
595
+ fk = (a+1)/2 * (ncols-1) # map [-1,1] to [0,ncols-1]
596
+ k0 = np.floor(fk).astype('int')
597
+ k1 = k0+1
598
+ k1[k1==ncols] = 0
599
+ f = fk-k0
600
+
601
+ if not saturate:
602
+ rad = np.minimum(rad,1)
603
+
604
+ # compute the image
605
+ img = np.zeros( (flow.shape[0],flow.shape[1],nchans), 'uint8' )
606
+ for i in range(nchans):
607
+ tmp = colorwheel[:,i].astype('float')
608
+ col0 = tmp[k0]/255
609
+ col1 = tmp[k1]/255
610
+ col = (1-f)*col0 + f*col1
611
+ idx = (rad <= 1)
612
+ col[idx] = 1-rad[idx]*(1-col[idx]) # increase saturation with radius
613
+ col[~idx] *= 0.75 # out of range
614
+ img[:,:,i] = (255*col*(1-nanidx.astype('float'))).astype('uint8')
615
+
616
+ return img
617
+
618
+ # flow dataset getter
619
+
620
+ def get_train_dataset_flow(dataset_str, augmentor=True, crop_size=None):
621
+ dataset_str = dataset_str.replace('(','Dataset(')
622
+ if augmentor:
623
+ dataset_str = dataset_str.replace(')',', augmentor=True)')
624
+ if crop_size is not None:
625
+ dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size)))
626
+ return eval(dataset_str)
627
+
628
+ def get_test_datasets_flow(dataset_str):
629
+ dataset_str = dataset_str.replace('(','Dataset(')
630
+ return [eval(s) for s in dataset_str.split('+')]
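Finally, a hedged sketch of how the `--dataset` strings from the train/test commands at the top of this diff are resolved by the two getters above. It only runs if the corresponding data sits under the hard-coded `./data/stereoflow/` roots (a cache of pair names is built on first use), so treat it as illustrative; the dataset string and crop size are shortened example values.

```python
from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow  # assumed import path

# '40*MPISintel(...)' becomes 40*MPISintelDataset(...): the '40*' goes through
# FlowDataset.__rmul__ (which repeats the pair list), and '+' concatenates datasets
# via the standard torch.utils.data.Dataset.__add__.
train_set = get_train_dataset_flow(
    "40*MPISintel('subtrain_cleanpass')+FlyingChairs('train')",
    augmentor=True, crop_size=(368, 496))

# At test time, each '+'-separated term stays a separate dataset object.
test_sets = get_test_datasets_flow("MPISintel('subval_cleanpass')+MPISintel('subval_finalpass')")
img1, img2, flow, pairname = test_sets[0][0]  # normalized image tensors, flow tensor, pair id
```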
third_party/dust3r/croco/stereoflow/datasets_stereo.py ADDED
@@ -0,0 +1,674 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Dataset structure for stereo
6
+ # --------------------------------------------------------
7
+
8
+ import sys, os
9
+ import os.path as osp
10
+ import pickle
11
+ import numpy as np
12
+ from PIL import Image
13
+ import json
14
+ import h5py
15
+ from glob import glob
16
+ import cv2
17
+
18
+ import torch
19
+ from torch.utils import data
20
+
21
+ from .augmentor import StereoAugmentor
22
+
23
+
24
+
25
+ dataset_to_root = {
26
+ 'CREStereo': './data/stereoflow//crenet_stereo_trainset/stereo_trainset/crestereo/',
27
+ 'SceneFlow': './data/stereoflow//SceneFlow/',
28
+ 'ETH3DLowRes': './data/stereoflow/eth3d_lowres/',
29
+ 'Booster': './data/stereoflow/booster_gt/',
30
+ 'Middlebury2021': './data/stereoflow/middlebury/2021/data/',
31
+ 'Middlebury2014': './data/stereoflow/middlebury/2014/',
32
+ 'Middlebury2006': './data/stereoflow/middlebury/2006/',
33
+ 'Middlebury2005': './data/stereoflow/middlebury/2005/train/',
34
+ 'MiddleburyEval3': './data/stereoflow/middlebury/MiddEval3/',
35
+ 'Spring': './data/stereoflow/spring/',
36
+ 'Kitti15': './data/stereoflow/kitti-stereo-2015/',
37
+ 'Kitti12': './data/stereoflow/kitti-stereo-2012/',
38
+ }
39
+ cache_dir = "./data/stereoflow/datasets_stereo_cache/"
40
+
41
+
42
+ in1k_mean = torch.tensor([0.485, 0.456, 0.406]).view(3,1,1)
43
+ in1k_std = torch.tensor([0.229, 0.224, 0.225]).view(3,1,1)
44
+ def img_to_tensor(img):
45
+ img = torch.from_numpy(img).permute(2, 0, 1).float() / 255.
46
+ img = (img-in1k_mean)/in1k_std
47
+ return img
48
+ def disp_to_tensor(disp):
49
+ return torch.from_numpy(disp)[None,:,:]
50
+
51
+ class StereoDataset(data.Dataset):
52
+
53
+ def __init__(self, split, augmentor=False, crop_size=None, totensor=True):
54
+ self.split = split
55
+ if not augmentor: assert crop_size is None
56
+ if crop_size: assert augmentor
57
+ self.crop_size = crop_size
58
+ self.augmentor_str = augmentor
59
+ self.augmentor = StereoAugmentor(crop_size) if augmentor else None
60
+ self.totensor = totensor
61
+ self.rmul = 1 # keep track of rmul
62
+ self.has_constant_resolution = True # whether the dataset has constant resolution or not (=> don't use batch_size>1 at test time)
63
+ self._prepare_data()
64
+ self._load_or_build_cache()
65
+
66
+ def _prepare_data(self):
67
+ """
68
+ to be defined for each dataset
69
+ """
70
+ raise NotImplementedError
71
+
72
+ def __len__(self):
73
+ return len(self.pairnames)
74
+
75
+ def __getitem__(self, index):
76
+ pairname = self.pairnames[index]
77
+
78
+ # get filenames
79
+ Limgname = self.pairname_to_Limgname(pairname)
80
+ Rimgname = self.pairname_to_Rimgname(pairname)
81
+ Ldispname = self.pairname_to_Ldispname(pairname) if self.pairname_to_Ldispname is not None else None
82
+
83
+ # load images and disparities
84
+ Limg = _read_img(Limgname)
85
+ Rimg = _read_img(Rimgname)
86
+ disp = self.load_disparity(Ldispname) if Ldispname is not None else None
87
+
88
+ # sanity check
89
+ if disp is not None: assert np.all(disp>0) or self.name=="Spring", (self.name, pairname, Ldispname)
90
+
91
+ # apply augmentations
92
+ if self.augmentor is not None:
93
+ Limg, Rimg, disp = self.augmentor(Limg, Rimg, disp, self.name)
94
+
95
+ if self.totensor:
96
+ Limg = img_to_tensor(Limg)
97
+ Rimg = img_to_tensor(Rimg)
98
+ if disp is None:
99
+ disp = torch.tensor([]) # to allow dataloader batching with default collate_gn
100
+ else:
101
+ disp = disp_to_tensor(disp)
102
+
103
+ return Limg, Rimg, disp, str(pairname)
104
+
105
+ def __rmul__(self, v):
106
+ self.rmul *= v
107
+ self.pairnames = v * self.pairnames
108
+ return self
109
+
110
+ def __str__(self):
111
+ return f'{self.__class__.__name__}_{self.split}'
112
+
113
+ def __repr__(self):
114
+ s = f'{self.__class__.__name__}(split={self.split}, augmentor={self.augmentor_str}, crop_size={str(self.crop_size)}, totensor={self.totensor})'
115
+ if self.rmul==1:
116
+ s+=f'\n\tnum pairs: {len(self.pairnames)}'
117
+ else:
118
+ s+=f'\n\tnum pairs: {len(self.pairnames)} ({len(self.pairnames)//self.rmul}x{self.rmul})'
119
+ return s
120
+
121
+ def _set_root(self):
122
+ self.root = dataset_to_root[self.name]
123
+ assert os.path.isdir(self.root), f"could not find root directory for dataset {self.name}: {self.root}"
124
+
125
+ def _load_or_build_cache(self):
126
+ cache_file = osp.join(cache_dir, self.name+'.pkl')
127
+ if osp.isfile(cache_file):
128
+ with open(cache_file, 'rb') as fid:
129
+ self.pairnames = pickle.load(fid)[self.split]
130
+ else:
131
+ tosave = self._build_cache()
132
+ os.makedirs(cache_dir, exist_ok=True)
133
+ with open(cache_file, 'wb') as fid:
134
+ pickle.dump(tosave, fid)
135
+ self.pairnames = tosave[self.split]
136
+
137
+ class CREStereoDataset(StereoDataset):
138
+
139
+ def _prepare_data(self):
140
+ self.name = 'CREStereo'
141
+ self._set_root()
142
+ assert self.split in ['train']
143
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_left.jpg')
144
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'_right.jpg')
145
+ self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname+'_left.disp.png')
146
+ self.pairname_to_str = lambda pairname: pairname
147
+ self.load_disparity = _read_crestereo_disp
148
+
149
+
150
+ def _build_cache(self):
151
+ allpairs = [s+'/'+f[:-len('_left.jpg')] for s in sorted(os.listdir(self.root)) for f in sorted(os.listdir(self.root+'/'+s)) if f.endswith('_left.jpg')]
152
+ assert len(allpairs)==200000, "incorrect parsing of pairs in CreStereo"
153
+ tosave = {'train': allpairs}
154
+ return tosave
155
+
156
+ class SceneFlowDataset(StereoDataset):
157
+
158
+ def _prepare_data(self):
159
+ self.name = "SceneFlow"
160
+ self._set_root()
161
+ assert self.split in ['train_finalpass','train_cleanpass','train_allpass','test_finalpass','test_cleanpass','test_allpass','test1of100_cleanpass','test1of100_finalpass']
162
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
163
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/left/','/right/')
164
+ self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname).replace('/frames_finalpass/','/disparity/').replace('/frames_cleanpass/','/disparity/')[:-4]+'.pfm'
165
+ self.pairname_to_str = lambda pairname: pairname[:-4]
166
+ self.load_disparity = _read_sceneflow_disp
167
+
168
+ def _build_cache(self):
169
+ trainpairs = []
170
+ # driving
171
+ pairs = sorted(glob(self.root+'Driving/frames_finalpass/*/*/*/left/*.png'))
172
+ pairs = list(map(lambda x: x[len(self.root):], pairs))
173
+ assert len(pairs) == 4400, "incorrect parsing of pairs in SceneFlow"
174
+ trainpairs += pairs
175
+ # monkaa
176
+ pairs = sorted(glob(self.root+'Monkaa/frames_finalpass/*/left/*.png'))
177
+ pairs = list(map(lambda x: x[len(self.root):], pairs))
178
+ assert len(pairs) == 8664, "incorrect parsing of pairs in SceneFlow"
179
+ trainpairs += pairs
180
+ # flyingthings
181
+ pairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TRAIN/*/*/left/*.png'))
182
+ pairs = list(map(lambda x: x[len(self.root):], pairs))
183
+ assert len(pairs) == 22390, "incorrect parsing of pairs in SceneFlow"
184
+ trainpairs += pairs
185
+ assert len(trainpairs) == 35454, "incorrect parsing of pairs in SceneFlow"
186
+ testpairs = sorted(glob(self.root+'FlyingThings/frames_finalpass/TEST/*/*/left/*.png'))
187
+ testpairs = list(map(lambda x: x[len(self.root):], testpairs))
188
+ assert len(testpairs) == 4370, "incorrect parsing of pairs in SceneFlow"
189
+ test1of100pairs = testpairs[::100]
190
+ assert len(test1of100pairs) == 44, "incorrect parsing of pairs in SceneFlow"
191
+ # all
192
+ tosave = {'train_finalpass': trainpairs,
193
+ 'train_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), trainpairs)),
194
+ 'test_finalpass': testpairs,
195
+ 'test_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), testpairs)),
196
+ 'test1of100_finalpass': test1of100pairs,
197
+ 'test1of100_cleanpass': list(map(lambda x: x.replace('frames_finalpass','frames_cleanpass'), test1of100pairs)),
198
+ }
199
+ tosave['train_allpass'] = tosave['train_finalpass']+tosave['train_cleanpass']
200
+ tosave['test_allpass'] = tosave['test_finalpass']+tosave['test_cleanpass']
201
+ return tosave
202
+
203
+ class Md21Dataset(StereoDataset):
204
+
205
+ def _prepare_data(self):
206
+ self.name = "Middlebury2021"
207
+ self._set_root()
208
+ assert self.split in ['train','subtrain','subval']
209
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
210
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/im0','/im1'))
211
+ self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp0.pfm')
212
+ self.pairname_to_str = lambda pairname: pairname[:-4]
213
+ self.load_disparity = _read_middlebury_disp
214
+
215
+ def _build_cache(self):
216
+ seqs = sorted(os.listdir(self.root))
217
+ trainpairs = []
218
+ for s in seqs:
219
+ #trainpairs += [s+'/im0.png'] # we should remove it, it is included as such in other lightings
220
+ trainpairs += [s+'/ambient/'+b+'/'+a for b in sorted(os.listdir(osp.join(self.root,s,'ambient'))) for a in sorted(os.listdir(osp.join(self.root,s,'ambient',b))) if a.startswith('im0')]
221
+ assert len(trainpairs)==355
222
+ subtrainpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[:-2])]
223
+ subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in seqs[-2:])]
224
+ assert len(subtrainpairs)==335 and len(subvalpairs)==20, "incorrect parsing of pairs in Middlebury 2021"
225
+ tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
226
+ return tosave
227
+
228
+ class Md14Dataset(StereoDataset):
229
+
230
+ def _prepare_data(self):
231
+ self.name = "Middlebury2014"
232
+ self._set_root()
233
+ assert self.split in ['train','subtrain','subval']
234
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'im0.png')
235
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname)
236
+ self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'disp0.pfm')
237
+ self.pairname_to_str = lambda pairname: pairname[:-4]
238
+ self.load_disparity = _read_middlebury_disp
239
+ self.has_constant_resolution = False
240
+
241
+ def _build_cache(self):
242
+ seqs = sorted(os.listdir(self.root))
243
+ trainpairs = []
244
+ for s in seqs:
245
+ trainpairs += [s+'/im1.png',s+'/im1E.png',s+'/im1L.png']
246
+ assert len(trainpairs)==138
247
+ valseqs = ['Umbrella-imperfect','Vintage-perfect']
248
+ assert all(s in seqs for s in valseqs)
249
+ subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
250
+ subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
251
+ assert len(subtrainpairs)==132 and len(subvalpairs)==6, "incorrect parsing of pairs in Middlebury 2014"
252
+ tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
253
+ return tosave
254
+
255
+ class Md06Dataset(StereoDataset):
256
+
257
+ def _prepare_data(self):
258
+ self.name = "Middlebury2006"
259
+ self._set_root()
260
+ assert self.split in ['train','subtrain','subval']
261
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
262
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
263
+ self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
264
+ self.load_disparity = _read_middlebury20052006_disp
265
+ self.has_constant_resolution = False
266
+
267
+ def _build_cache(self):
268
+ seqs = sorted(os.listdir(self.root))
269
+ trainpairs = []
270
+ for s in seqs:
271
+ for i in ['Illum1','Illum2','Illum3']:
272
+ for e in ['Exp0','Exp1','Exp2']:
273
+ trainpairs.append(osp.join(s,i,e,'view1.png'))
274
+ assert len(trainpairs)==189
275
+ valseqs = ['Rocks1','Wood2']
276
+ assert all(s in seqs for s in valseqs)
277
+ subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
278
+ subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
279
+ assert len(subtrainpairs)==171 and len(subvalpairs)==18, "incorrect parsing of pairs in Middlebury 2006"
280
+ tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
281
+ return tosave
282
+
283
+ class Md05Dataset(StereoDataset):
284
+
285
+ def _prepare_data(self):
286
+ self.name = "Middlebury2005"
287
+ self._set_root()
288
+ assert self.split in ['train','subtrain','subval']
289
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
290
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, osp.dirname(pairname), 'view5.png')
291
+ self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, pairname.split('/')[0], 'disp1.png')
292
+ self.pairname_to_str = lambda pairname: pairname[:-4]
293
+ self.load_disparity = _read_middlebury20052006_disp
294
+
295
+ def _build_cache(self):
296
+ seqs = sorted(os.listdir(self.root))
297
+ trainpairs = []
298
+ for s in seqs:
299
+ for i in ['Illum1','Illum2','Illum3']:
300
+ for e in ['Exp0','Exp1','Exp2']:
301
+ trainpairs.append(osp.join(s,i,e,'view1.png'))
302
+ assert len(trainpairs)==54, "incorrect parsing of pairs in Middlebury 2005"
303
+ valseqs = ['Reindeer']
304
+ assert all(s in seqs for s in valseqs)
305
+ subtrainpairs = [p for p in trainpairs if not any(p.startswith(s+'/') for s in valseqs)]
306
+ subvalpairs = [p for p in trainpairs if any(p.startswith(s+'/') for s in valseqs)]
307
+ assert len(subtrainpairs)==45 and len(subvalpairs)==9, "incorrect parsing of pairs in Middlebury 2005"
308
+ tosave = {'train': trainpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
309
+ return tosave
310
+
311
+ class MdEval3Dataset(StereoDataset):
312
+
313
+ def _prepare_data(self):
314
+ self.name = "MiddleburyEval3"
315
+ self._set_root()
316
+ assert self.split in [s+'_'+r for s in ['train','subtrain','subval','test','all'] for r in ['full','half','quarter']]
317
+ if self.split.endswith('_full'):
318
+ self.root = self.root.replace('/MiddEval3','/MiddEval3_F')
319
+ elif self.split.endswith('_half'):
320
+ self.root = self.root.replace('/MiddEval3','/MiddEval3_H')
321
+ else:
322
+ assert self.split.endswith('_quarter')
323
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
324
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')
325
+ self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname, 'disp0GT.pfm')
326
+ self.pairname_to_str = lambda pairname: pairname
327
+ self.load_disparity = _read_middlebury_disp
328
+ # for submission only
329
+ self.submission_methodname = "CroCo-Stereo"
330
+ self.submission_sresolution = 'F' if self.split.endswith('_full') else ('H' if self.split.endswith('_half') else 'Q')
331
+
332
+ def _build_cache(self):
333
+ trainpairs = ['train/'+s for s in sorted(os.listdir(self.root+'train/'))]
334
+ testpairs = ['test/'+s for s in sorted(os.listdir(self.root+'test/'))]
335
+ subvalpairs = trainpairs[-1:]
336
+ subtrainpairs = trainpairs[:-1]
337
+ allpairs = trainpairs+testpairs
338
+ assert len(trainpairs)==15 and len(testpairs)==15 and len(subvalpairs)==1 and len(subtrainpairs)==14 and len(allpairs)==30, "incorrect parsing of pairs in Middlebury Eval v3"
339
+ tosave = {}
340
+ for r in ['full','half','quarter']:
341
+ tosave.update(**{'train_'+r: trainpairs, 'subtrain_'+r: subtrainpairs, 'subval_'+r: subvalpairs, 'test_'+r: testpairs, 'all_'+r: allpairs})
342
+ return tosave
343
+
344
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
345
+ assert prediction.ndim==2
346
+ assert prediction.dtype==np.float32
347
+ outfile = os.path.join(outdir, pairname.split('/')[0].replace('train','training')+self.submission_sresolution, pairname.split('/')[1], 'disp0'+self.submission_methodname+'.pfm')
348
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
349
+ writePFM(outfile, prediction)
350
+ timefile = os.path.join( os.path.dirname(outfile), "time"+self.submission_methodname+'.txt')
351
+ with open(timefile, 'w') as fid:
352
+ fid.write(str(time))
353
+
354
+ def finalize_submission(self, outdir):
355
+ cmd = f'cd {outdir}/; zip -r "{self.submission_methodname}.zip" .'
356
+ print(cmd)
357
+ os.system(cmd)
358
+ print(f'Done. Submission file at {outdir}/{self.submission_methodname}.zip')
359
+
360
+ class ETH3DLowResDataset(StereoDataset):
361
+
362
+ def _prepare_data(self):
363
+ self.name = "ETH3DLowRes"
364
+ self._set_root()
365
+ assert self.split in ['train','test','subtrain','subval','all']
366
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname, 'im0.png')
367
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname, 'im1.png')
368
+ self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: None if pairname.startswith('test/') else osp.join(self.root, pairname.replace('train/','train_gt/'), 'disp0GT.pfm')
369
+ self.pairname_to_str = lambda pairname: pairname
370
+ self.load_disparity = _read_eth3d_disp
371
+ self.has_constant_resolution = False
372
+
373
+ def _build_cache(self):
374
+ trainpairs = ['train/' + s for s in sorted(os.listdir(self.root+'train/'))]
375
+ testpairs = ['test/' + s for s in sorted(os.listdir(self.root+'test/'))]
376
+ assert len(trainpairs) == 27 and len(testpairs) == 20, "incorrect parsing of pairs in ETH3D Low Res"
377
+ subvalpairs = ['train/delivery_area_3s','train/electro_3l','train/playground_3l']
378
+ assert all(p in trainpairs for p in subvalpairs)
379
+ subtrainpairs = [p for p in trainpairs if not p in subvalpairs]
380
+ assert len(subvalpairs)==3 and len(subtrainpairs)==24, "incorrect parsing of pairs in ETH3D Low Res"
381
+ tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs, 'all': trainpairs+testpairs}
382
+ return tosave
383
+
384
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
385
+ assert prediction.ndim==2
386
+ assert prediction.dtype==np.float32
387
+ outfile = os.path.join(outdir, 'low_res_two_view', pairname.split('/')[1]+'.pfm')
388
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
389
+ writePFM(outfile, prediction)
390
+ timefile = outfile[:-4]+'.txt'
391
+ with open(timefile, 'w') as fid:
392
+ fid.write('runtime '+str(time))
393
+
394
+ def finalize_submission(self, outdir):
395
+ cmd = f'cd {outdir}/; zip -r "eth3d_low_res_two_view_results.zip" low_res_two_view'
396
+ print(cmd)
397
+ os.system(cmd)
398
+ print(f'Done. Submission file at {outdir}/eth3d_low_res_two_view_results.zip')
399
+
400
+ class BoosterDataset(StereoDataset):
401
+
402
+ def _prepare_data(self):
403
+ self.name = "Booster"
404
+ self._set_root()
405
+ assert self.split in ['train_balanced','test_balanced','subtrain_balanced','subval_balanced'] # we use only the balanced version
406
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname)
407
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname).replace('/camera_00/','/camera_02/')
408
+ self.pairname_to_Ldispname = lambda pairname: osp.join(self.root, osp.dirname(pairname), '../disp_00.npy') # same images with different colors, same gt per sequence
409
+ self.pairname_to_str = lambda pairname: pairname[:-4].replace('/camera_00/','/')
410
+ self.load_disparity = _read_booster_disp
411
+
412
+
413
+ def _build_cache(self):
414
+ trainseqs = sorted(os.listdir(self.root+'train/balanced'))
415
+ trainpairs = ['train/balanced/'+s+'/camera_00/'+imname for s in trainseqs for imname in sorted(os.listdir(self.root+'train/balanced/'+s+'/camera_00/'))]
416
+ testpairs = ['test/balanced/'+s+'/camera_00/'+imname for s in sorted(os.listdir(self.root+'test/balanced')) for imname in sorted(os.listdir(self.root+'test/balanced/'+s+'/camera_00/'))]
417
+ assert len(trainpairs) == 228 and len(testpairs) == 191
418
+ subtrainpairs = [p for p in trainpairs if any(s in p for s in trainseqs[:-2])]
419
+ subvalpairs = [p for p in trainpairs if any(s in p for s in trainseqs[-2:])]
420
+ # warning: if we do validation split, we should split scenes!!!
421
+ tosave = {'train_balanced': trainpairs, 'test_balanced': testpairs, 'subtrain_balanced': subtrainpairs, 'subval_balanced': subvalpairs,}
422
+ return tosave
423
+
424
+ class SpringDataset(StereoDataset):
425
+
426
+ def _prepare_data(self):
427
+ self.name = "Spring"
428
+ self._set_root()
429
+ assert self.split in ['train', 'test', 'subtrain', 'subval']
430
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'.png')
431
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname+'.png').replace('frame_right','<frame_right>').replace('frame_left','frame_right').replace('<frame_right>','frame_left')
432
+ self.pairname_to_Ldispname = lambda pairname: None if pairname.startswith('test') else osp.join(self.root, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')
433
+ self.pairname_to_str = lambda pairname: pairname
434
+ self.load_disparity = _read_hdf5_disp
435
+
436
+ def _build_cache(self):
437
+ trainseqs = sorted(os.listdir( osp.join(self.root,'train')))
438
+ trainpairs = [osp.join('train',s,'frame_left',f[:-4]) for s in trainseqs for f in sorted(os.listdir(osp.join(self.root,'train',s,'frame_left')))]
439
+ testseqs = sorted(os.listdir( osp.join(self.root,'test')))
440
+ testpairs = [osp.join('test',s,'frame_left',f[:-4]) for s in testseqs for f in sorted(os.listdir(osp.join(self.root,'test',s,'frame_left')))]
441
+ testpairs += [p.replace('frame_left','frame_right') for p in testpairs]
442
+ """maxnorm = {'0001': 32.88, '0002': 228.5, '0004': 298.2, '0005': 142.5, '0006': 113.6, '0007': 27.3, '0008': 554.5, '0009': 155.6, '0010': 126.1, '0011': 87.6, '0012': 303.2, '0013': 24.14, '0014': 82.56, '0015': 98.44, '0016': 156.9, '0017': 28.17, '0018': 21.03, '0020': 178.0, '0021': 58.06, '0022': 354.2, '0023': 8.79, '0024': 97.06, '0025': 55.16, '0026': 91.9, '0027': 156.6, '0030': 200.4, '0032': 58.66, '0033': 373.5, '0036': 149.4, '0037': 5.625, '0038': 37.0, '0039': 12.2, '0041': 453.5, '0043': 457.0, '0044': 379.5, '0045': 161.8, '0047': 105.44} # => let'use 0041"""
443
+ subtrainpairs = [p for p in trainpairs if p.split('/')[1]!='0041']
444
+ subvalpairs = [p for p in trainpairs if p.split('/')[1]=='0041']
445
+ assert len(trainpairs)==5000 and len(testpairs)==2000 and len(subtrainpairs)==4904 and len(subvalpairs)==96, "incorrect parsing of pairs in Spring"
446
+ tosave = {'train': trainpairs, 'test': testpairs, 'subtrain': subtrainpairs, 'subval': subvalpairs}
447
+ return tosave
448
+
449
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
450
+ assert prediction.ndim==2
451
+ assert prediction.dtype==np.float32
452
+ outfile = os.path.join(outdir, pairname+'.dsp5').replace('frame_left','disp1_left').replace('frame_right','disp1_right')
453
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
454
+ writeDsp5File(prediction, outfile)
455
+
456
+ def finalize_submission(self, outdir):
457
+ assert self.split=='test'
458
+ exe = f"{self.root}/disp1_subsampling"
459
+ if os.path.isfile(exe):
460
+ cmd = f'cd "{outdir}/test"; {exe} .'
461
+ print(cmd)
462
+ os.system(cmd)
463
+ else:
464
+ print('Could not find disp1_subsampling executable for submission.')
465
+ print('Please download it and run:')
466
+ print(f'cd "{outdir}/test"; <disp1_subsampling_exe> .')
467
+
468
+ class Kitti12Dataset(StereoDataset):
469
+
470
+ def _prepare_data(self):
471
+ self.name = "Kitti12"
472
+ self._set_root()
473
+ assert self.split in ['train','test']
474
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
475
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/colored_1/')+'_10.png')
476
+ self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/colored_0/','/disp_occ/')+'_10.png')
477
+ self.pairname_to_str = lambda pairname: pairname.replace('/colored_0/','/')
478
+ self.load_disparity = _read_kitti_disp
479
+
480
+ def _build_cache(self):
481
+ trainseqs = ["training/colored_0/%06d"%(i) for i in range(194)]
482
+ testseqs = ["testing/colored_0/%06d"%(i) for i in range(195)]
483
+ assert len(trainseqs)==194 and len(testseqs)==195, "incorrect parsing of pairs in Kitti12"
484
+ tosave = {'train': trainseqs, 'test': testseqs}
485
+ return tosave
486
+
487
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
488
+ assert prediction.ndim==2
489
+ assert prediction.dtype==np.float32
490
+ outfile = os.path.join(outdir, pairname.split('/')[-1]+'_10.png')
491
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
492
+ img = (prediction * 256).astype('uint16')
493
+ Image.fromarray(img).save(outfile)
494
+
495
+ def finalize_submission(self, outdir):
496
+ assert self.split=='test'
497
+ cmd = f'cd {outdir}/; zip -r "kitti12_results.zip" .'
498
+ print(cmd)
499
+ os.system(cmd)
500
+ print(f'Done. Submission file at {outdir}/kitti12_results.zip')
501
+
502
+ class Kitti15Dataset(StereoDataset):
503
+
504
+ def _prepare_data(self):
505
+ self.name = "Kitti15"
506
+ self._set_root()
507
+ assert self.split in ['train','subtrain','subval','test']
508
+ self.pairname_to_Limgname = lambda pairname: osp.join(self.root, pairname+'_10.png')
509
+ self.pairname_to_Rimgname = lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/image_3/')+'_10.png')
510
+ self.pairname_to_Ldispname = None if self.split=='test' else lambda pairname: osp.join(self.root, pairname.replace('/image_2/','/disp_occ_0/')+'_10.png')
511
+ self.pairname_to_str = lambda pairname: pairname.replace('/image_2/','/')
512
+ self.load_disparity = _read_kitti_disp
513
+
514
+ def _build_cache(self):
515
+ trainseqs = ["training/image_2/%06d"%(i) for i in range(200)]
516
+ subtrainseqs = trainseqs[:-5]
517
+ subvalseqs = trainseqs[-5:]
518
+ testseqs = ["testing/image_2/%06d"%(i) for i in range(200)]
519
+ assert len(trainseqs)==200 and len(subtrainseqs)==195 and len(subvalseqs)==5 and len(testseqs)==200, "incorrect parsing of pairs in Kitti15"
520
+ tosave = {'train': trainseqs, 'subtrain': subtrainseqs, 'subval': subvalseqs, 'test': testseqs}
521
+ return tosave
522
+
523
+ def submission_save_pairname(self, pairname, prediction, outdir, time):
524
+ assert prediction.ndim==2
525
+ assert prediction.dtype==np.float32
526
+ outfile = os.path.join(outdir, 'disp_0', pairname.split('/')[-1]+'_10.png')
527
+ os.makedirs( os.path.dirname(outfile), exist_ok=True)
528
+ img = (prediction * 256).astype('uint16')
529
+ Image.fromarray(img).save(outfile)
530
+
531
+ def finalize_submission(self, outdir):
532
+ assert self.split=='test'
533
+ cmd = f'cd {outdir}/; zip -r "kitti15_results.zip" disp_0'
534
+ print(cmd)
535
+ os.system(cmd)
536
+ print(f'Done. Submission file at {outdir}/kitti15_results.zip')
537
+
538
+
539
+ ### auxiliary functions
540
+
541
+ def _read_img(filename):
542
+ # convert to RGB for scene flow finalpass data
543
+ img = np.asarray(Image.open(filename).convert('RGB'))
544
+ return img
545
+
546
+ def _read_booster_disp(filename):
547
+ disp = np.load(filename)
548
+ disp[disp==0.0] = np.inf
549
+ return disp
550
+
551
+ def _read_png_disp(filename, coef=1.0):
552
+ disp = np.asarray(Image.open(filename))
553
+ disp = disp.astype(np.float32) / coef
554
+ disp[disp==0.0] = np.inf
555
+ return disp
556
+
557
+ def _read_pfm_disp(filename):
558
+ disp = np.ascontiguousarray(_read_pfm(filename)[0])
559
+ disp[disp<=0] = np.inf # eg /nfs/data/ffs-3d/datasets/middlebury/2014/Shopvac-imperfect/disp0.pfm
560
+ return disp
561
+
562
+ def _read_npy_disp(filename):
563
+ return np.load(filename)
564
+
565
+ def _read_crestereo_disp(filename): return _read_png_disp(filename, coef=32.0)
566
+ def _read_middlebury20052006_disp(filename): return _read_png_disp(filename, coef=1.0)
567
+ def _read_kitti_disp(filename): return _read_png_disp(filename, coef=256.0)
568
+ _read_sceneflow_disp = _read_pfm_disp
569
+ _read_eth3d_disp = _read_pfm_disp
570
+ _read_middlebury_disp = _read_pfm_disp
571
+ _read_carla_disp = _read_pfm_disp
572
+ _read_tartanair_disp = _read_npy_disp
573
+
574
+ def _read_hdf5_disp(filename):
575
+ disp = np.asarray(h5py.File(filename)['disparity'])
576
+ disp[np.isnan(disp)] = np.inf # make invalid values as +inf
577
+ #disp[disp==0.0] = np.inf # make invalid values as +inf
578
+ return disp.astype(np.float32)
579
+
580
+ import re
581
+ def _read_pfm(file):
582
+ file = open(file, 'rb')
583
+
584
+ color = None
585
+ width = None
586
+ height = None
587
+ scale = None
588
+ endian = None
589
+
590
+ header = file.readline().rstrip()
591
+ if header.decode("ascii") == 'PF':
592
+ color = True
593
+ elif header.decode("ascii") == 'Pf':
594
+ color = False
595
+ else:
596
+ raise Exception('Not a PFM file.')
597
+
598
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
599
+ if dim_match:
600
+ width, height = list(map(int, dim_match.groups()))
601
+ else:
602
+ raise Exception('Malformed PFM header.')
603
+
604
+ scale = float(file.readline().decode("ascii").rstrip())
605
+ if scale < 0: # little-endian
606
+ endian = '<'
607
+ scale = -scale
608
+ else:
609
+ endian = '>' # big-endian
610
+
611
+ data = np.fromfile(file, endian + 'f')
612
+ shape = (height, width, 3) if color else (height, width)
613
+
614
+ data = np.reshape(data, shape)
615
+ data = np.flipud(data)
616
+ return data, scale
617
+
618
+ def writePFM(file, image, scale=1):
619
+ file = open(file, 'wb')
620
+
621
+ color = None
622
+
623
+ if image.dtype.name != 'float32':
624
+ raise Exception('Image dtype must be float32.')
625
+
626
+ image = np.flipud(image)
627
+
628
+ if len(image.shape) == 3 and image.shape[2] == 3: # color image
629
+ color = True
630
+ elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
631
+ color = False
632
+ else:
633
+ raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.')
634
+
635
+ file.write(('PF\n' if color else 'Pf\n').encode())
636
+ file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0]))
637
+
638
+ endian = image.dtype.byteorder
639
+
640
+ if endian == '<' or endian == '=' and sys.byteorder == 'little':
641
+ scale = -scale
642
+
643
+ file.write('%f\n'.encode() % scale)
644
+
645
+ image.tofile(file)
646
+
647
+ def writeDsp5File(disp, filename):
648
+ with h5py.File(filename, "w") as f:
649
+ f.create_dataset("disparity", data=disp, compression="gzip", compression_opts=5)
650
+
651
+
652
+ # disp visualization
653
+
654
+ def vis_disparity(disp, m=None, M=None):
655
+ if m is None: m = disp.min()
656
+ if M is None: M = disp.max()
657
+ disp_vis = (disp - m) / (M-m) * 255.0
658
+ disp_vis = disp_vis.astype("uint8")
659
+ disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
660
+ return disp_vis
661
+
662
+ # dataset getter
663
+
664
+ def get_train_dataset_stereo(dataset_str, augmentor=True, crop_size=None):
665
+ dataset_str = dataset_str.replace('(','Dataset(')
666
+ if augmentor:
667
+ dataset_str = dataset_str.replace(')',', augmentor=True)')
668
+ if crop_size is not None:
669
+ dataset_str = dataset_str.replace(')',', crop_size={:s})'.format(str(crop_size)))
670
+ return eval(dataset_str)
671
+
672
+ def get_test_datasets_stereo(dataset_str):
673
+ dataset_str = dataset_str.replace('(','Dataset(')
674
+ return [eval(s) for s in dataset_str.split('+')]
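The two getters above build datasets from a compact string syntax: `'('` is rewritten to `'Dataset('`, `'+'` concatenates datasets (via torch's `Dataset.__add__`, which returns a `ConcatDataset`), and an integer prefix such as `3*` repeats a dataset through `__rmul__`. The snippet below is a hedged usage sketch; the dataset strings are examples, the crop size is the stereo training default from `train.py`, and the constructors only succeed if the corresponding data roots under `./data/stereoflow/` exist (their `_set_root()` asserts it).

```python
# Hedged usage sketch of the string-based getters; requires the data roots
# listed in dataset_to_root to exist, otherwise _set_root() will assert.
from stereoflow.datasets_stereo import get_train_dataset_stereo, get_test_datasets_stereo

# "CREStereo('train')" is rewritten to
# "CREStereoDataset('train', augmentor=True, crop_size=(352, 704))" before being eval'd.
train_set = get_train_dataset_stereo(
    "3*CREStereo('train')+SceneFlow('train_finalpass')",
    augmentor=True, crop_size=(352, 704))
print(len(train_set))  # 3*200000 + 35454 pairs if the caches parse correctly

# test datasets come back as a list, one entry per '+'-separated term
for ds in get_test_datasets_stereo("MdEval3('subval_quarter')+ETH3DLowRes('subval')"):
    print(repr(ds))
```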
third_party/dust3r/croco/stereoflow/download_model.sh ADDED
@@ -0,0 +1,12 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ model=$1
5
+ outfile="stereoflow_models/${model}"
6
+ if [[ ! -f $outfile ]]
7
+ then
8
+ mkdir -p stereoflow_models/;
9
+ wget https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/$1 -P stereoflow_models/;
10
+ else
11
+ echo "Model ${model} already downloaded in ${outfile}."
12
+ fi
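The script caches a checkpoint under `stereoflow_models/` and skips the download if it is already there. The sketch below mirrors the same logic in Python for environments without `wget`; the checkpoint filename in the commented call is a hypothetical placeholder, not a name confirmed by this repository.

```python
# Python sketch mirroring download_model.sh: fetch a StereoFlow checkpoint if absent.
# The filename in the example call is a hypothetical placeholder.
import os
import urllib.request

def download_stereoflow_model(model: str, outdir: str = "stereoflow_models") -> str:
    outfile = os.path.join(outdir, model)
    if not os.path.isfile(outfile):
        os.makedirs(outdir, exist_ok=True)
        url = f"https://download.europe.naverlabs.com/ComputerVision/CroCo/StereoFlow_models/{model}"
        urllib.request.urlretrieve(url, outfile)
    else:
        print(f"Model {model} already downloaded in {outfile}.")
    return outfile

# download_stereoflow_model("crocostereo.pth")  # hypothetical checkpoint name
```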
third_party/dust3r/croco/stereoflow/engine.py ADDED
@@ -0,0 +1,280 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Main function for training one epoch or testing
6
+ # --------------------------------------------------------
7
+
8
+ import math
9
+ import sys
10
+ from typing import Iterable
11
+ import numpy as np
12
+ import torch
13
+ import torchvision
14
+
15
+ from utils import misc as misc
16
+
17
+
18
+ def split_prediction_conf(predictions, with_conf=False):
19
+ if not with_conf:
20
+ return predictions, None
21
+ conf = predictions[:,-1:,:,:]
22
+ predictions = predictions[:,:-1,:,:]
23
+ return predictions, conf
24
+
25
+ def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, metrics: torch.nn.Module,
26
+ data_loader: Iterable, optimizer: torch.optim.Optimizer,
27
+ device: torch.device, epoch: int, loss_scaler,
28
+ log_writer=None, print_freq = 20,
29
+ args=None):
30
+ model.train(True)
31
+ metric_logger = misc.MetricLogger(delimiter=" ")
32
+ metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
33
+ header = 'Epoch: [{}]'.format(epoch)
34
+
35
+ accum_iter = args.accum_iter
36
+
37
+ optimizer.zero_grad()
38
+
39
+ details = {}
40
+
41
+ if log_writer is not None:
42
+ print('log_dir: {}'.format(log_writer.log_dir))
43
+
44
+ if args.img_per_epoch:
45
+ iter_per_epoch = args.img_per_epoch // args.batch_size + int(args.img_per_epoch % args.batch_size > 0)
46
+ assert len(data_loader) >= iter_per_epoch, 'Dataset is too small for so many iterations'
47
+ len_data_loader = iter_per_epoch
48
+ else:
49
+ len_data_loader, iter_per_epoch = len(data_loader), None
50
+
51
+ for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_logger.log_every(data_loader, print_freq, header, max_iter=iter_per_epoch)):
52
+
53
+ image1 = image1.to(device, non_blocking=True)
54
+ image2 = image2.to(device, non_blocking=True)
55
+ gt = gt.to(device, non_blocking=True)
56
+
57
+ # we use a per iteration (instead of per epoch) lr scheduler
58
+ if data_iter_step % accum_iter == 0:
59
+ misc.adjust_learning_rate(optimizer, data_iter_step / len_data_loader + epoch, args)
60
+
61
+ with torch.cuda.amp.autocast(enabled=bool(args.amp)):
62
+ prediction = model(image1, image2)
63
+ prediction, conf = split_prediction_conf(prediction, criterion.with_conf)
64
+ batch_metrics = metrics(prediction.detach(), gt)
65
+ loss = criterion(prediction, gt) if conf is None else criterion(prediction, gt, conf)
66
+
67
+ loss_value = loss.item()
68
+ if not math.isfinite(loss_value):
69
+ print("Loss is {}, stopping training".format(loss_value))
70
+ sys.exit(1)
71
+
72
+ loss /= accum_iter
73
+ loss_scaler(loss, optimizer, parameters=model.parameters(),
74
+ update_grad=(data_iter_step + 1) % accum_iter == 0)
75
+ if (data_iter_step + 1) % accum_iter == 0:
76
+ optimizer.zero_grad()
77
+
78
+ torch.cuda.synchronize()
79
+
80
+ metric_logger.update(loss=loss_value)
81
+ for k,v in batch_metrics.items():
82
+ metric_logger.update(**{k: v.item()})
83
+ lr = optimizer.param_groups[0]["lr"]
84
+ metric_logger.update(lr=lr)
85
+
86
+ #if args.distributed: loss_value_reduce = misc.all_reduce_mean(loss_value)
87
+ time_to_log = ((data_iter_step + 1) % (args.tboard_log_step * accum_iter) == 0 or data_iter_step == len_data_loader-1)
88
+ loss_value_reduce = misc.all_reduce_mean(loss_value)
89
+ if log_writer is not None and time_to_log:
90
+ epoch_1000x = int((data_iter_step / len_data_loader + epoch) * 1000)
91
+ # We use epoch_1000x as the x-axis in tensorboard. This calibrates different curves when batch size changes.
92
+ log_writer.add_scalar('train/loss', loss_value_reduce, epoch_1000x)
93
+ log_writer.add_scalar('lr', lr, epoch_1000x)
94
+ for k,v in batch_metrics.items():
95
+ log_writer.add_scalar('train/'+k, v.item(), epoch_1000x)
96
+
97
+ # gather the stats from all processes
98
+ #if args.distributed: metric_logger.synchronize_between_processes()
99
+ print("Averaged stats:", metric_logger)
100
+ return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
101
+
102
+
103
+ @torch.no_grad()
104
+ def validate_one_epoch(model: torch.nn.Module,
105
+ criterion: torch.nn.Module,
106
+ metrics: torch.nn.Module,
107
+ data_loaders: list[Iterable],
108
+ device: torch.device,
109
+ epoch: int,
110
+ log_writer=None,
111
+ args=None):
112
+
113
+ model.eval()
114
+ metric_loggers = []
115
+ header = 'Epoch: [{}]'.format(epoch)
116
+ print_freq = 20
117
+
118
+ conf_mode = args.tile_conf_mode
119
+ crop = args.crop
120
+
121
+ if log_writer is not None:
122
+ print('log_dir: {}'.format(log_writer.log_dir))
123
+
124
+ results = {}
125
+ dnames = []
126
+ image1, image2, gt, prediction = None, None, None, None
127
+ for didx, data_loader in enumerate(data_loaders):
128
+ dname = str(data_loader.dataset)
129
+ dnames.append(dname)
130
+ metric_loggers.append(misc.MetricLogger(delimiter=" "))
131
+ for data_iter_step, (image1, image2, gt, pairname) in enumerate(metric_loggers[didx].log_every(data_loader, print_freq, header)):
132
+ image1 = image1.to(device, non_blocking=True)
133
+ image2 = image2.to(device, non_blocking=True)
134
+ gt = gt.to(device, non_blocking=True)
135
+ if dname.startswith('Spring'):
136
+ assert gt.size(2)==image1.size(2)*2 and gt.size(3)==image1.size(3)*2
137
+ gt = (gt[:,:,0::2,0::2] + gt[:,:,0::2,1::2] + gt[:,:,1::2,0::2] + gt[:,:,1::2,1::2] ) / 4.0 # we approximate the gt based on the 2x upsampled ones
138
+
139
+ with torch.inference_mode():
140
+ prediction, tiled_loss, c = tiled_pred(model, criterion, image1, image2, gt, conf_mode=conf_mode, overlap=args.val_overlap, crop=crop, with_conf=criterion.with_conf)
141
+ batch_metrics = metrics(prediction.detach(), gt)
142
+ loss = criterion(prediction.detach(), gt) if not criterion.with_conf else criterion(prediction.detach(), gt, c)
143
+ loss_value = loss.item()
144
+ metric_loggers[didx].update(loss_tiled=tiled_loss.item())
145
+ metric_loggers[didx].update(**{f'loss': loss_value})
146
+ for k,v in batch_metrics.items():
147
+ metric_loggers[didx].update(**{dname+'_' + k: v.item()})
148
+
149
+ results = {k: meter.global_avg for ml in metric_loggers for k, meter in ml.meters.items()}
150
+ if len(dnames)>1:
151
+ for k in batch_metrics.keys():
152
+ results['AVG_'+k] = sum(results[dname+'_'+k] for dname in dnames) / len(dnames)
153
+
154
+ if log_writer is not None :
155
+ epoch_1000x = int((1 + epoch) * 1000)
156
+ for k,v in results.items():
157
+ log_writer.add_scalar('val/'+k, v, epoch_1000x)
158
+
159
+ print("Averaged stats:", results)
160
+ return results
161
+
162
+ import torch.nn.functional as F
163
+ def _resize_img(img, new_size):
164
+ return F.interpolate(img, size=new_size, mode='bicubic', align_corners=False)
165
+ def _resize_stereo_or_flow(data, new_size):
166
+ assert data.ndim==4
167
+ assert data.size(1) in [1,2]
168
+ scale_x = new_size[1]/float(data.size(3))
169
+ out = F.interpolate(data, size=new_size, mode='bicubic', align_corners=False)
170
+ out[:,0,:,:] *= scale_x
171
+ if out.size(1)==2:
172
+ scale_y = new_size[0]/float(data.size(2))
173
+ out[:,1,:,:] *= scale_y
174
+ print(scale_x, new_size, data.shape)
175
+ return out
176
+
177
+
178
+ @torch.no_grad()
179
+ def tiled_pred(model, criterion, img1, img2, gt,
180
+ overlap=0.5, bad_crop_thr=0.05,
181
+ downscale=False, crop=512, ret='loss',
182
+ conf_mode='conf_expsigmoid_10_5', with_conf=False,
183
+ return_time=False):
184
+
185
+ # for each image, we are going to run inference on many overlapping patches
186
+ # then, all predictions will be weighted-averaged
187
+ if gt is not None:
188
+ B, C, H, W = gt.shape
189
+ else:
190
+ B, _, H, W = img1.shape
191
+ C = model.head.num_channels-int(with_conf)
192
+ win_height, win_width = crop[0], crop[1]
193
+
194
+ # upscale to be larger than the crop
195
+ do_change_scale = H<win_height or W<win_width
196
+ if do_change_scale:
197
+ upscale_factor = max(win_width/W, win_height/H)
198
+ original_size = (H,W)
199
+ new_size = (round(H*upscale_factor),round(W*upscale_factor))
200
+ img1 = _resize_img(img1, new_size)
201
+ img2 = _resize_img(img2, new_size)
202
+ # resize gt just for the computation of tiled losses
203
+ if gt is not None: gt = _resize_stereo_or_flow(gt, new_size)
204
+ H,W = img1.shape[2:4]
205
+
206
+ if conf_mode.startswith('conf_expsigmoid_'): # conf_expsigmoid_30_10
207
+ beta, betasigmoid = map(float, conf_mode[len('conf_expsigmoid_'):].split('_'))
208
+ elif conf_mode.startswith('conf_expbeta'): # conf_expbeta3
209
+ beta = float(conf_mode[len('conf_expbeta'):])
210
+ else:
211
+ raise NotImplementedError(f"conf_mode {conf_mode} is not implemented")
212
+
213
+ def crop_generator():
214
+ for sy in _overlapping(H, win_height, overlap):
215
+ for sx in _overlapping(W, win_width, overlap):
216
+ yield sy, sx, sy, sx, True
217
+
218
+ # keep track of weighted sum of prediction*weights and weights
219
+ accu_pred = img1.new_zeros((B, C, H, W)) # accumulate the weighted sum of predictions
220
+ accu_conf = img1.new_zeros((B, H, W)) + 1e-16 # accumulate the weights
221
+ accu_c = img1.new_zeros((B, H, W)) # accumulate the weighted sum of confidences ; not so useful except for computing some losses
222
+
223
+ tiled_losses = []
224
+
225
+ if return_time:
226
+ start = torch.cuda.Event(enable_timing=True)
227
+ end = torch.cuda.Event(enable_timing=True)
228
+ start.record()
229
+
230
+ for sy1, sx1, sy2, sx2, aligned in crop_generator():
231
+ # compute optical flow there
232
+ pred = model(_crop(img1,sy1,sx1), _crop(img2,sy2,sx2))
233
+ pred, predconf = split_prediction_conf(pred, with_conf=with_conf)
234
+
235
+ if gt is not None: gtcrop = _crop(gt,sy1,sx1)
236
+ if criterion is not None and gt is not None:
237
+ tiled_losses.append( criterion(pred, gtcrop).item() if predconf is None else criterion(pred, gtcrop, predconf).item() )
238
+
239
+ if conf_mode.startswith('conf_expsigmoid_'):
240
+ conf = torch.exp(- beta * 2 * (torch.sigmoid(predconf / betasigmoid) - 0.5)).view(B,win_height,win_width)
241
+ elif conf_mode.startswith('conf_expbeta'):
242
+ conf = torch.exp(- beta * predconf).view(B,win_height,win_width)
243
+ else:
244
+ raise NotImplementedError
245
+
246
+ accu_pred[...,sy1,sx1] += pred * conf[:,None,:,:]
247
+ accu_conf[...,sy1,sx1] += conf
248
+ accu_c[...,sy1,sx1] += predconf.view(B,win_height,win_width) * conf
249
+
250
+ pred = accu_pred / accu_conf[:, None,:,:]
251
+ c = accu_c / accu_conf
252
+ assert not torch.any(torch.isnan(pred))
253
+
254
+ if return_time:
255
+ end.record()
256
+ torch.cuda.synchronize()
257
+ time = start.elapsed_time(end)/1000.0 # this was in milliseconds
258
+
259
+ if do_change_scale:
260
+ pred = _resize_stereo_or_flow(pred, original_size)
261
+
262
+ if return_time:
263
+ return pred, torch.mean(torch.tensor(tiled_losses)), c, time
264
+ return pred, torch.mean(torch.tensor(tiled_losses)), c
265
+
266
+
267
+ def _overlapping(total, window, overlap=0.5):
268
+ assert total >= window and 0 <= overlap < 1, (total, window, overlap)
269
+ num_windows = 1 + int(np.ceil( (total - window) / ((1-overlap) * window) ))
270
+ offsets = np.linspace(0, total-window, num_windows).round().astype(int)
271
+ yield from (slice(x, x+window) for x in offsets)
272
+
273
+ def _crop(img, sy, sx):
274
+ B, THREE, H, W = img.shape
275
+ if 0 <= sy.start and sy.stop <= H and 0 <= sx.start and sx.stop <= W:
276
+ return img[:,:,sy,sx]
277
+ l, r = max(0,-sx.start), max(0,sx.stop-W)
278
+ t, b = max(0,-sy.start), max(0,sy.stop-H)
279
+ img = torch.nn.functional.pad(img, (l,r,t,b), mode='constant')
280
+ return img[:, :, slice(sy.start+t,sy.stop+t), slice(sx.start+l,sx.stop+l)]
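`tiled_pred` above runs the network on overlapping crops and blends the per-tile predictions with confidence-derived weights; the helpers `_overlapping` and `_crop` do the window bookkeeping. The snippet below exercises just those two helpers; it is a minimal sketch, assuming the croco root is on `PYTHONPATH`, and the image and tile sizes are arbitrary example values (the tile size matches the stereo training crop default).

```python
# Minimal sketch of the tiling used by tiled_pred; sizes are illustrative only.
# Assumes the croco root is on PYTHONPATH so that stereoflow.engine imports.
import torch
from stereoflow.engine import _overlapping, _crop

H, W = 736, 1024            # full-resolution input (example values)
win_h, win_w = 352, 704     # stereo training crop size used as tile size

rows = list(_overlapping(H, win_h, overlap=0.7))   # vertical slices covering [0, H)
cols = list(_overlapping(W, win_w, overlap=0.7))   # horizontal slices covering [0, W)
print(len(rows) * len(cols), "overlapping tiles")

# every tile returned by _crop has exactly the window size (zero-padded if out of bounds)
img = torch.zeros(1, 3, H, W)
tile = _crop(img, rows[-1], cols[-1])
assert tile.shape[-2:] == (win_h, win_w)
```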
third_party/dust3r/croco/stereoflow/test.py ADDED
@@ -0,0 +1,216 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Main test function
6
+ # --------------------------------------------------------
7
+
8
+ import os
9
+ import argparse
10
+ import pickle
11
+ from PIL import Image
12
+ import numpy as np
13
+ from tqdm import tqdm
14
+
15
+ import torch
16
+ from torch.utils.data import DataLoader
17
+
18
+ import utils.misc as misc
19
+ from models.croco_downstream import CroCoDownstreamBinocular
20
+ from models.head_downstream import PixelwiseTaskWithDPT
21
+
22
+ from stereoflow.criterion import *
23
+ from stereoflow.datasets_stereo import get_test_datasets_stereo
24
+ from stereoflow.datasets_flow import get_test_datasets_flow
25
+ from stereoflow.engine import tiled_pred
26
+
27
+ from stereoflow.datasets_stereo import vis_disparity
28
+ from stereoflow.datasets_flow import flowToColor
29
+
30
+ def get_args_parser():
31
+ parser = argparse.ArgumentParser('Test CroCo models on stereo/flow', add_help=False)
32
+ # important argument
33
+ parser.add_argument('--model', required=True, type=str, help='Path to the model to evaluate')
34
+ parser.add_argument('--dataset', required=True, type=str, help="test dataset (there can be multiple dataset separated by a +)")
35
+ # tiling
36
+ parser.add_argument('--tile_conf_mode', type=str, default='', help='Weights for the tiling aggregation based on confidence (empty means use the formula from the loaded checkpoint)')
37
+ parser.add_argument('--tile_overlap', type=float, default=0.7, help='overlap between tiles')
38
+ # save (it will automatically go to <model_path>_<dataset_str>/<tile_str>_<save>)
39
+ parser.add_argument('--save', type=str, nargs='+', default=[],
40
+ help='what to save: \
41
+ metrics (pickle file), \
42
+ pred (raw prediction save as torch tensor), \
43
+ visu (visualization in png of each prediction), \
44
+ err10 (visualization in png of the error clamp at 10 for each prediction), \
45
+ submission (submission file)')
46
+ # other (no impact)
47
+ parser.add_argument('--num_workers', default=4, type=int)
48
+ return parser
49
+
50
+
51
+ def _load_model_and_criterion(model_path, do_load_metrics, device):
52
+ print('loading model from', model_path)
53
+ assert os.path.isfile(model_path)
54
+ ckpt = torch.load(model_path, 'cpu')
55
+
56
+ ckpt_args = ckpt['args']
57
+ task = ckpt_args.task
58
+ tile_conf_mode = ckpt_args.tile_conf_mode
59
+ num_channels = {'stereo': 1, 'flow': 2}[task]
60
+ with_conf = eval(ckpt_args.criterion).with_conf
61
+ if with_conf: num_channels += 1
62
+ print('head: PixelwiseTaskWithDPT()')
63
+ head = PixelwiseTaskWithDPT()
64
+ head.num_channels = num_channels
65
+ print('croco_args:', ckpt_args.croco_args)
66
+ model = CroCoDownstreamBinocular(head, **ckpt_args.croco_args)
67
+ msg = model.load_state_dict(ckpt['model'], strict=True)
68
+ model.eval()
69
+ model = model.to(device)
70
+
71
+ if do_load_metrics:
72
+ if task=='stereo':
73
+ metrics = StereoDatasetMetrics().to(device)
74
+ else:
75
+ metrics = FlowDatasetMetrics().to(device)
76
+ else:
77
+ metrics = None
78
+
79
+ return model, metrics, ckpt_args.crop, with_conf, task, tile_conf_mode
80
+
81
+
82
+ def _save_batch(pred, gt, pairnames, dataset, task, save, outdir, time, submission_dir=None):
83
+
84
+ for i in range(len(pairnames)):
85
+
86
+ pairname = eval(pairnames[i]) if pairnames[i].startswith('(') else pairnames[i] # unbatch pairname
87
+ fname = os.path.join(outdir, dataset.pairname_to_str(pairname))
88
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
89
+
90
+ predi = pred[i,...]
91
+ if gt is not None: gti = gt[i,...]
92
+
93
+ if 'pred' in save:
94
+ torch.save(predi.squeeze(0).cpu(), fname+'_pred.pth')
95
+
96
+ if 'visu' in save:
97
+ if task=='stereo':
98
+ disparity = predi.permute((1,2,0)).squeeze(2).cpu().numpy()
99
+ m, M = None, None
100
+ if gt is not None:
101
+ mask = torch.isfinite(gti)
102
+ m = gt[mask].min()
103
+ M = gt[mask].max()
104
+ img_disparity = vis_disparity(disparity, m=m, M=M)
105
+ Image.fromarray(img_disparity).save(fname+'_pred.png')
106
+ else:
107
+ # normalize flowToColor according to the maxnorm of gt (or prediction if not available)
108
+ flowNorm = torch.sqrt(torch.sum( (gti if gt is not None else predi)**2, dim=0)).max().item()
109
+ imgflow = flowToColor(predi.permute((1,2,0)).cpu().numpy(), maxflow=flowNorm)
110
+ Image.fromarray(imgflow).save(fname+'_pred.png')
111
+
112
+ if 'err10' in save:
113
+ assert gt is not None
114
+ L2err = torch.sqrt(torch.sum( (gti-predi)**2, dim=0))
115
+ valid = torch.isfinite(gti[0,:,:])
116
+ L2err[~valid] = 0.0
117
+ L2err = torch.clamp(L2err, max=10.0)
118
+ red = (L2err*255.0/10.0).to(dtype=torch.uint8)[:,:,None]
119
+ zer = torch.zeros_like(red)
120
+ imgerr = torch.cat( (red,zer,zer), dim=2).cpu().numpy()
121
+ Image.fromarray(imgerr).save(fname+'_err10.png')
122
+
123
+ if 'submission' in save:
124
+ assert submission_dir is not None
125
+ predi_np = predi.permute(1,2,0).squeeze(2).cpu().numpy() # transform into HxWx2 for flow or HxW for stereo
126
+ dataset.submission_save_pairname(pairname, predi_np, submission_dir, time)
127
+
128
+ def main(args):
129
+
130
+ # load the pretrained model and metrics
131
+ device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
132
+ model, metrics, cropsize, with_conf, task, tile_conf_mode = _load_model_and_criterion(args.model, 'metrics' in args.save, device)
133
+ if args.tile_conf_mode=='': args.tile_conf_mode = tile_conf_mode
134
+
135
+ # load the datasets
136
+ datasets = (get_test_datasets_stereo if task=='stereo' else get_test_datasets_flow)(args.dataset)
137
+ dataloaders = [DataLoader(dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for dataset in datasets]
138
+
139
+ # run
140
+ for i,dataloader in enumerate(dataloaders):
141
+ dataset = datasets[i]
142
+ dstr = args.dataset.split('+')[i]
143
+
144
+ outdir = args.model+'_'+misc.filename(dstr)
145
+ if 'metrics' in args.save and len(args.save)==1:
146
+ fname = os.path.join(outdir, f'conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}.pkl')
147
+ if os.path.isfile(fname) and len(args.save)==1:
148
+ print(' metrics already computed in '+fname)
149
+ with open(fname, 'rb') as fid:
150
+ results = pickle.load(fid)
151
+ for k,v in results.items():
152
+ print('{:s}: {:.3f}'.format(k, v))
153
+ continue
154
+
155
+ if 'submission' in args.save:
156
+ dirname = f'submission_conf_{args.tile_conf_mode}_overlap_{args.tile_overlap}'
157
+ submission_dir = os.path.join(outdir, dirname)
158
+ else:
159
+ submission_dir = None
160
+
161
+ print('')
162
+ print('saving {:s} in {:s}'.format('+'.join(args.save), outdir))
163
+ print(repr(dataset))
164
+
165
+ if metrics is not None:
166
+ metrics.reset()
167
+
168
+ for data_iter_step, (image1, image2, gt, pairnames) in enumerate(tqdm(dataloader)):
169
+
170
+ do_flip = (task=='stereo' and dstr.startswith('Spring') and any("right" in p for p in pairnames)) # we flip the images and will flip the prediction after as we assume img1 is on the left
171
+
172
+ image1 = image1.to(device, non_blocking=True)
173
+ image2 = image2.to(device, non_blocking=True)
174
+ gt = gt.to(device, non_blocking=True) if gt.numel()>0 else None # special case for test time
175
+ if do_flip:
176
+ assert all("right" in p for p in pairnames)
177
+ image1 = image1.flip(dims=[3]) # this is already the right frame, let's flip it
178
+ image2 = image2.flip(dims=[3])
179
+ gt = gt # that is ok
180
+
181
+ with torch.inference_mode():
182
+ pred, _, _, time = tiled_pred(model, None, image1, image2, None if dataset.name=='Spring' else gt, conf_mode=args.tile_conf_mode, overlap=args.tile_overlap, crop=cropsize, with_conf=with_conf, return_time=True)
183
+
184
+ if do_flip:
185
+ pred = pred.flip(dims=[3])
186
+
187
+ if metrics is not None:
188
+ metrics.add_batch(pred, gt)
189
+
190
+ if any(k in args.save for k in ['pred','visu','err10','submission']):
191
+ _save_batch(pred, gt, pairnames, dataset, task, args.save, outdir, time, submission_dir=submission_dir)
192
+
193
+
194
+ # print
195
+ if metrics is not None:
196
+ results = metrics.get_results()
197
+ for k,v in results.items():
198
+ print('{:s}: {:.3f}'.format(k, v))
199
+
200
+ # save if needed
201
+ if 'metrics' in args.save:
202
+ os.makedirs(os.path.dirname(fname), exist_ok=True)
203
+ with open(fname, 'wb') as fid:
204
+ pickle.dump(results, fid)
205
+ print('metrics saved in', fname)
206
+
207
+ # finalize submission if needed
208
+ if 'submission' in args.save:
209
+ dataset.finalize_submission(submission_dir)
210
+
211
+
212
+
213
+ if __name__ == '__main__':
214
+ args = get_args_parser()
215
+ args = args.parse_args()
216
+ main(args)
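Running `stereoflow/test.py` evaluates a finetuned checkpoint with the tiled inference above and can optionally dump predictions, visualizations, or submission files. The snippet below drives it programmatically instead of from the shell; it is a hedged sketch, and both the checkpoint path and the dataset string are placeholders that must match what is actually on disk (the task, stereo or flow, is read from the checkpoint itself).

```python
# Hedged sketch: invoke the test entry point programmatically.
# The checkpoint path and dataset string are placeholders, not guaranteed to exist.
from stereoflow.test import get_args_parser, main

args = get_args_parser().parse_args([
    '--model', 'stereoflow_models/crocostereo.pth',    # hypothetical checkpoint file
    '--dataset', "ETH3DLowRes('subval')",              # parsed by get_test_datasets_stereo
    '--tile_overlap', '0.7',
    '--save', 'metrics',                               # write a .pkl with the metrics
])
main(args)
```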
third_party/dust3r/croco/stereoflow/train.py ADDED
@@ -0,0 +1,253 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+
4
+ # --------------------------------------------------------
5
+ # Main training function
6
+ # --------------------------------------------------------
7
+
8
+ import argparse
9
+ import datetime
10
+ import json
11
+ import numpy as np
12
+ import os
13
+ import sys
14
+ import time
15
+
16
+ import torch
17
+ import torch.distributed as dist
18
+ import torch.backends.cudnn as cudnn
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ import torchvision.transforms as transforms
21
+ import torchvision.datasets as datasets
22
+ from torch.utils.data import DataLoader
23
+
24
+ import utils
25
+ import utils.misc as misc
26
+ from utils.misc import NativeScalerWithGradNormCount as NativeScaler
27
+ from models.croco_downstream import CroCoDownstreamBinocular, croco_args_from_ckpt
28
+ from models.pos_embed import interpolate_pos_embed
29
+ from models.head_downstream import PixelwiseTaskWithDPT
30
+
31
+ from stereoflow.datasets_stereo import get_train_dataset_stereo, get_test_datasets_stereo
32
+ from stereoflow.datasets_flow import get_train_dataset_flow, get_test_datasets_flow
33
+ from stereoflow.engine import train_one_epoch, validate_one_epoch
34
+ from stereoflow.criterion import *
35
+
36
+
37
+ def get_args_parser():
38
+ # prepare subparsers
39
+ parser = argparse.ArgumentParser('Finetuning CroCo models on stereo or flow', add_help=False)
40
+ subparsers = parser.add_subparsers(title="Task (stereo or flow)", dest="task", required=True)
41
+ parser_stereo = subparsers.add_parser('stereo', help='Training stereo model')
42
+ parser_flow = subparsers.add_parser('flow', help='Training flow model')
43
+ def add_arg(name_or_flags, default=None, default_stereo=None, default_flow=None, **kwargs):
44
+ if default is not None: assert default_stereo is None and default_flow is None, "setting default makes default_stereo and default_flow disabled"
45
+ parser_stereo.add_argument(name_or_flags, default=default if default is not None else default_stereo, **kwargs)
46
+ parser_flow.add_argument(name_or_flags, default=default if default is not None else default_flow, **kwargs)
47
+ # output dir
48
+ add_arg('--output_dir', required=True, type=str, help='path where to save outputs (created automatically if it does not exist)')
49
+ # model
50
+ add_arg('--crop', type=int, nargs='+', default_stereo=[352, 704], default_flow=[320, 384], help="size of the random image crops used during training.")
51
+ add_arg('--pretrained', required=True, type=str, help="Load pretrained model (required as croco arguments come from there)")
52
+ # criterion
53
+ add_arg('--criterion', default_stereo='LaplacianLossBounded2()', default_flow='LaplacianLossBounded()', type=str, help='string to evaluate to get criterion')
54
+ add_arg('--bestmetric', default_stereo='avgerr', default_flow='EPE', type=str)
55
+ # dataset
56
+ add_arg('--dataset', type=str, required=True, help="training set")
57
+ # training
58
+ add_arg('--seed', default=0, type=int, help='seed')
59
+ add_arg('--batch_size', default_stereo=6, default_flow=8, type=int, help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
60
+ add_arg('--epochs', default=32, type=int, help='number of training epochs')
61
+ add_arg('--img_per_epoch', type=int, default=None, help='Fix the number of images seen in an epoch (None means use all training pairs)')
62
+ add_arg('--accum_iter', default=1, type=int, help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')
63
+ add_arg('--weight_decay', type=float, default=0.05, help='weight decay (default: 0.05)')
64
+ add_arg('--lr', type=float, default_stereo=3e-5, default_flow=2e-5, metavar='LR', help='learning rate (absolute lr)')
65
+ add_arg('--min_lr', type=float, default=0., metavar='LR', help='lower lr bound for cyclic schedulers that hit 0')
66
+ add_arg('--warmup_epochs', type=int, default=1, metavar='N', help='epochs to warmup LR')
67
+ add_arg('--optimizer', default='AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))', type=str,
68
+ help="Optimizer from torch.optim [ default: AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95)) ]")
69
+ add_arg('--amp', default=0, type=int, choices=[0,1], help='enable automatic mixed precision training')
70
+ # validation
71
+ add_arg('--val_dataset', type=str, default='', help="Validation sets, multiple separated by + (empty string means that no validation is performed)")
72
+ add_arg('--tile_conf_mode', type=str, default_stereo='conf_expsigmoid_15_3', default_flow='conf_expsigmoid_10_5', help='Weights for tile aggregation')
73
+ add_arg('--val_overlap', default=0.7, type=float, help='Overlap value for the tiling')
74
+ # others
75
+ add_arg('--num_workers', default=8, type=int)
76
+ add_arg('--eval_every', type=int, default=1, help='Val loss evaluation frequency')
77
+ add_arg('--save_every', type=int, default=1, help='Save checkpoint frequency')
78
+ add_arg('--start_from', type=str, default=None, help='Start training from the weights of another model (e.g. for finetuning)')
79
+ add_arg('--tboard_log_step', type=int, default=100, help='Log to tboard every so many steps')
80
+ add_arg('--dist_url', default='env://', help='url used to set up distributed training')
81
+
82
+ return parser
83
+
84
+
85
+ def main(args):
86
+ misc.init_distributed_mode(args)
87
+ global_rank = misc.get_rank()
88
+ num_tasks = misc.get_world_size()
89
+
90
+ assert os.path.isfile(args.pretrained)
91
+ print("output_dir: "+args.output_dir)
92
+ os.makedirs(args.output_dir, exist_ok=True)
93
+
94
+ # fix the seed for reproducibility
95
+ seed = args.seed + misc.get_rank()
96
+ torch.manual_seed(seed)
97
+ np.random.seed(seed)
98
+ cudnn.benchmark = True
99
+
100
+ # Metrics / criterion
101
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
102
+ metrics = (StereoMetrics if args.task=='stereo' else FlowMetrics)().to(device)
103
+ criterion = eval(args.criterion).to(device)
104
+ print('Criterion: ', args.criterion)
105
+
106
+ # Prepare model
107
+ assert os.path.isfile(args.pretrained)
108
+ ckpt = torch.load(args.pretrained, 'cpu')
109
+ croco_args = croco_args_from_ckpt(ckpt)
110
+ croco_args['img_size'] = (args.crop[0], args.crop[1])
111
+ print('Croco args: '+str(croco_args))
112
+ args.croco_args = croco_args # saved for test time
113
+ # prepare head
114
+ num_channels = {'stereo': 1, 'flow': 2}[args.task]
115
+ if criterion.with_conf: num_channels += 1
116
+ print(f'Building head PixelwiseTaskWithDPT() with {num_channels} channel(s)')
117
+ head = PixelwiseTaskWithDPT()
118
+ head.num_channels = num_channels
119
+ # build model and load pretrained weights
120
+ model = CroCoDownstreamBinocular(head, **croco_args)
121
+ interpolate_pos_embed(model, ckpt['model'])
122
+ msg = model.load_state_dict(ckpt['model'], strict=False)
123
+ print(msg)
124
+
125
+ total_params = sum(p.numel() for p in model.parameters())
126
+ total_params_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
127
+ print(f"Total params: {total_params}")
128
+ print(f"Total params trainable: {total_params_trainable}")
129
+ model_without_ddp = model.to(device)
130
+
131
+ eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
132
+ print("lr: %.2e" % args.lr)
133
+ print("accumulate grad iterations: %d" % args.accum_iter)
134
+ print("effective batch size: %d" % eff_batch_size)
135
+
136
+ if args.distributed:
137
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], static_graph=True)
138
+ model_without_ddp = model.module
139
+
140
+ # following timm: set wd as 0 for bias and norm layers
141
+ param_groups = misc.get_parameter_groups(model_without_ddp, args.weight_decay)
142
+ optimizer = eval(f"torch.optim.{args.optimizer}")
143
+ print(optimizer)
144
+ loss_scaler = NativeScaler()
145
+
146
+ # automatic restart
147
+ last_ckpt_fname = os.path.join(args.output_dir, 'checkpoint-last.pth')
148
+ args.resume = last_ckpt_fname if os.path.isfile(last_ckpt_fname) else None
149
+
150
+ if not args.resume and args.start_from:
151
+ print(f"Starting from another model's weights: {args.start_from}")
152
+ best_so_far = None
153
+ args.start_epoch = 0
154
+ ckpt = torch.load(args.start_from, 'cpu')
155
+ msg = model_without_ddp.load_state_dict(ckpt['model'], strict=False)
156
+ print(msg)
157
+ else:
158
+ best_so_far = misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)
159
+
160
+ if best_so_far is None: best_so_far = np.inf
161
+
162
+ # tensorboard
163
+ log_writer = None
164
+ if global_rank == 0 and args.output_dir is not None:
165
+ log_writer = SummaryWriter(log_dir=args.output_dir, purge_step=args.start_epoch*1000)
166
+
167
+ # dataset and loader
168
+ print('Building Train Data loader for dataset: ', args.dataset)
169
+ train_dataset = (get_train_dataset_stereo if args.task=='stereo' else get_train_dataset_flow)(args.dataset, crop_size=args.crop)
170
+ def _print_repr_dataset(d):
171
+ if isinstance(d, torch.utils.data.dataset.ConcatDataset):
172
+ for dd in d.datasets:
173
+ _print_repr_dataset(dd)
174
+ else:
175
+ print(repr(d))
176
+ _print_repr_dataset(train_dataset)
177
+ print(' total length:', len(train_dataset))
178
+ if args.distributed:
179
+ sampler_train = torch.utils.data.DistributedSampler(
180
+ train_dataset, num_replicas=num_tasks, rank=global_rank, shuffle=True
181
+ )
182
+ else:
183
+ sampler_train = torch.utils.data.RandomSampler(train_dataset)
184
+ data_loader_train = torch.utils.data.DataLoader(
185
+ train_dataset, sampler=sampler_train,
186
+ batch_size=args.batch_size,
187
+ num_workers=args.num_workers,
188
+ pin_memory=True,
189
+ drop_last=True,
190
+ )
191
+ if args.val_dataset=='':
192
+ data_loaders_val = None
193
+ else:
194
+ print('Building Val Data loader for datasets: ', args.val_dataset)
195
+ val_datasets = (get_test_datasets_stereo if args.task=='stereo' else get_test_datasets_flow)(args.val_dataset)
196
+ for val_dataset in val_datasets: print(repr(val_dataset))
197
+ data_loaders_val = [DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=args.num_workers, pin_memory=True, drop_last=False) for val_dataset in val_datasets]
198
+ bestmetric = ("AVG_" if len(data_loaders_val)>1 else str(data_loaders_val[0].dataset)+'_')+args.bestmetric
199
+
200
+ print(f"Start training for {args.epochs} epochs")
201
+ start_time = time.time()
202
+ # Training Loop
203
+ for epoch in range(args.start_epoch, args.epochs):
204
+
205
+ if args.distributed: data_loader_train.sampler.set_epoch(epoch)
206
+
207
+ # Train
208
+ epoch_start = time.time()
209
+ train_stats = train_one_epoch(model, criterion, metrics, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args)
210
+ epoch_time = time.time() - epoch_start
211
+
212
+ if args.distributed: dist.barrier()
213
+
214
+ # Validation (note: this naive implementation runs the validation on every GPU, duplicating work across processes)
215
+ if data_loaders_val is not None and args.eval_every > 0 and (epoch+1) % args.eval_every == 0:
216
+ val_epoch_start = time.time()
217
+ val_stats = validate_one_epoch(model, criterion, metrics, data_loaders_val, device, epoch, log_writer=log_writer, args=args)
218
+ val_epoch_time = time.time() - val_epoch_start
219
+
220
+ val_best = val_stats[bestmetric]
221
+
222
+ # Save best of all
223
+ if val_best <= best_so_far:
224
+ best_so_far = val_best
225
+ misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='best')
226
+
227
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
228
+ 'epoch': epoch,
229
+ **{f'val_{k}': v for k, v in val_stats.items()}}
230
+ else:
231
+ log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
232
+ 'epoch': epoch,}
233
+
234
+ if args.distributed: dist.barrier()
235
+
236
+ # Save stuff
237
+ if args.output_dir and ((epoch+1) % args.save_every == 0 or epoch + 1 == args.epochs):
238
+ misc.save_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler, epoch=epoch, best_so_far=best_so_far, fname='last')
239
+
240
+ if args.output_dir:
241
+ if log_writer is not None:
242
+ log_writer.flush()
243
+ with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
244
+ f.write(json.dumps(log_stats) + "\n")
245
+
246
+ total_time = time.time() - start_time
247
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
248
+ print('Training time {}'.format(total_time_str))
249
+
250
+ if __name__ == '__main__':
251
+ args = get_args_parser()
252
+ args = args.parse_args()
253
+ main(args)
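
As a quick orientation (not part of the diff): the finetuning entry point above is driven through its two-level argument parser, with the task subparser (`stereo` or `flow`) selecting task-specific defaults. The snippet below is a minimal sketch of a programmatic invocation; all paths and the dataset string are placeholders, not values taken from this repository, and it assumes the croco root is the working directory.

```python
# Sketch only: programmatic equivalent of running `python stereoflow/train.py stereo ...`.
# All paths and the dataset name are placeholders.
from stereoflow.train import get_args_parser, main

args = get_args_parser().parse_args([
    'stereo',                                            # task subparser: 'stereo' or 'flow'
    '--output_dir', 'out/croco_stereo_finetune',         # placeholder output directory
    '--pretrained', 'pretrained/CroCo_V2_ViTLarge.pth',  # placeholder checkpoint path
    '--dataset', 'SceneFlow',                            # placeholder training-set name
])
main(args)
```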
third_party/dust3r/croco/utils/misc.py ADDED
@@ -0,0 +1,463 @@
1
+ # Copyright (C) 2022-present Naver Corporation. All rights reserved.
2
+ # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
+ #
4
+ # --------------------------------------------------------
5
+ # utilitary functions for CroCo
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # MAE: https://github.com/facebookresearch/mae
9
+ # DeiT: https://github.com/facebookresearch/deit
10
+ # BEiT: https://github.com/microsoft/unilm/tree/master/beit
11
+ # --------------------------------------------------------
12
+
13
+ import builtins
14
+ import datetime
15
+ import os
16
+ import time
17
+ import math
18
+ import json
19
+ from collections import defaultdict, deque
20
+ from pathlib import Path
21
+ import numpy as np
22
+
23
+ import torch
24
+ import torch.distributed as dist
25
+ from torch import inf
26
+
27
+ class SmoothedValue(object):
28
+ """Track a series of values and provide access to smoothed values over a
29
+ window or the global series average.
30
+ """
31
+
32
+ def __init__(self, window_size=20, fmt=None):
33
+ if fmt is None:
34
+ fmt = "{median:.4f} ({global_avg:.4f})"
35
+ self.deque = deque(maxlen=window_size)
36
+ self.total = 0.0
37
+ self.count = 0
38
+ self.fmt = fmt
39
+
40
+ def update(self, value, n=1):
41
+ self.deque.append(value)
42
+ self.count += n
43
+ self.total += value * n
44
+
45
+ def synchronize_between_processes(self):
46
+ """
47
+ Warning: does not synchronize the deque!
48
+ """
49
+ if not is_dist_avail_and_initialized():
50
+ return
51
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
52
+ dist.barrier()
53
+ dist.all_reduce(t)
54
+ t = t.tolist()
55
+ self.count = int(t[0])
56
+ self.total = t[1]
57
+
58
+ @property
59
+ def median(self):
60
+ d = torch.tensor(list(self.deque))
61
+ return d.median().item()
62
+
63
+ @property
64
+ def avg(self):
65
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
66
+ return d.mean().item()
67
+
68
+ @property
69
+ def global_avg(self):
70
+ return self.total / self.count
71
+
72
+ @property
73
+ def max(self):
74
+ return max(self.deque)
75
+
76
+ @property
77
+ def value(self):
78
+ return self.deque[-1]
79
+
80
+ def __str__(self):
81
+ return self.fmt.format(
82
+ median=self.median,
83
+ avg=self.avg,
84
+ global_avg=self.global_avg,
85
+ max=self.max,
86
+ value=self.value)
87
+
88
+
89
+ class MetricLogger(object):
90
+ def __init__(self, delimiter="\t"):
91
+ self.meters = defaultdict(SmoothedValue)
92
+ self.delimiter = delimiter
93
+
94
+ def update(self, **kwargs):
95
+ for k, v in kwargs.items():
96
+ if v is None:
97
+ continue
98
+ if isinstance(v, torch.Tensor):
99
+ v = v.item()
100
+ assert isinstance(v, (float, int))
101
+ self.meters[k].update(v)
102
+
103
+ def __getattr__(self, attr):
104
+ if attr in self.meters:
105
+ return self.meters[attr]
106
+ if attr in self.__dict__:
107
+ return self.__dict__[attr]
108
+ raise AttributeError("'{}' object has no attribute '{}'".format(
109
+ type(self).__name__, attr))
110
+
111
+ def __str__(self):
112
+ loss_str = []
113
+ for name, meter in self.meters.items():
114
+ loss_str.append(
115
+ "{}: {}".format(name, str(meter))
116
+ )
117
+ return self.delimiter.join(loss_str)
118
+
119
+ def synchronize_between_processes(self):
120
+ for meter in self.meters.values():
121
+ meter.synchronize_between_processes()
122
+
123
+ def add_meter(self, name, meter):
124
+ self.meters[name] = meter
125
+
126
+ def log_every(self, iterable, print_freq, header=None, max_iter=None):
127
+ i = 0
128
+ if not header:
129
+ header = ''
130
+ start_time = time.time()
131
+ end = time.time()
132
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
133
+ data_time = SmoothedValue(fmt='{avg:.4f}')
134
+ len_iterable = min(len(iterable), max_iter) if max_iter else len(iterable)
135
+ space_fmt = ':' + str(len(str(len_iterable))) + 'd'
136
+ log_msg = [
137
+ header,
138
+ '[{0' + space_fmt + '}/{1}]',
139
+ 'eta: {eta}',
140
+ '{meters}',
141
+ 'time: {time}',
142
+ 'data: {data}'
143
+ ]
144
+ if torch.cuda.is_available():
145
+ log_msg.append('max mem: {memory:.0f}')
146
+ log_msg = self.delimiter.join(log_msg)
147
+ MB = 1024.0 * 1024.0
148
+ for it,obj in enumerate(iterable):
149
+ data_time.update(time.time() - end)
150
+ yield obj
151
+ iter_time.update(time.time() - end)
152
+ if i % print_freq == 0 or i == len_iterable - 1:
153
+ eta_seconds = iter_time.global_avg * (len_iterable - i)
154
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
155
+ if torch.cuda.is_available():
156
+ print(log_msg.format(
157
+ i, len_iterable, eta=eta_string,
158
+ meters=str(self),
159
+ time=str(iter_time), data=str(data_time),
160
+ memory=torch.cuda.max_memory_allocated() / MB))
161
+ else:
162
+ print(log_msg.format(
163
+ i, len_iterable, eta=eta_string,
164
+ meters=str(self),
165
+ time=str(iter_time), data=str(data_time)))
166
+ i += 1
167
+ end = time.time()
168
+ if max_iter and it >= max_iter:
169
+ break
170
+ total_time = time.time() - start_time
171
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
172
+ print('{} Total time: {} ({:.4f} s / it)'.format(
173
+ header, total_time_str, total_time / len_iterable))
174
+
175
+
176
+ def setup_for_distributed(is_master):
177
+ """
178
+ This function disables printing when not in master process
179
+ """
180
+ builtin_print = builtins.print
181
+
182
+ def print(*args, **kwargs):
183
+ force = kwargs.pop('force', False)
184
+ force = force or (get_world_size() > 8)
185
+ if is_master or force:
186
+ now = datetime.datetime.now().time()
187
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
188
+ builtin_print(*args, **kwargs)
189
+
190
+ builtins.print = print
191
+
192
+
193
+ def is_dist_avail_and_initialized():
194
+ if not dist.is_available():
195
+ return False
196
+ if not dist.is_initialized():
197
+ return False
198
+ return True
199
+
200
+
201
+ def get_world_size():
202
+ if not is_dist_avail_and_initialized():
203
+ return 1
204
+ return dist.get_world_size()
205
+
206
+
207
+ def get_rank():
208
+ if not is_dist_avail_and_initialized():
209
+ return 0
210
+ return dist.get_rank()
211
+
212
+
213
+ def is_main_process():
214
+ return get_rank() == 0
215
+
216
+
217
+ def save_on_master(*args, **kwargs):
218
+ if is_main_process():
219
+ torch.save(*args, **kwargs)
220
+
221
+
222
+ def init_distributed_mode(args):
223
+ nodist = args.nodist if hasattr(args,'nodist') else False
224
+ if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ and not nodist:
225
+ args.rank = int(os.environ["RANK"])
226
+ args.world_size = int(os.environ['WORLD_SIZE'])
227
+ args.gpu = int(os.environ['LOCAL_RANK'])
228
+ else:
229
+ print('Not using distributed mode')
230
+ setup_for_distributed(is_master=True) # hack
231
+ args.distributed = False
232
+ return
233
+
234
+ args.distributed = True
235
+
236
+ torch.cuda.set_device(args.gpu)
237
+ args.dist_backend = 'nccl'
238
+ print('| distributed init (rank {}): {}, gpu {}'.format(
239
+ args.rank, args.dist_url, args.gpu), flush=True)
240
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
241
+ world_size=args.world_size, rank=args.rank)
242
+ torch.distributed.barrier()
243
+ setup_for_distributed(args.rank == 0)
244
+
245
+
246
+ class NativeScalerWithGradNormCount:
247
+ state_dict_key = "amp_scaler"
248
+
249
+ def __init__(self, enabled=True):
250
+ self._scaler = torch.cuda.amp.GradScaler(enabled=enabled)
251
+
252
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
253
+ self._scaler.scale(loss).backward(create_graph=create_graph)
254
+ if update_grad:
255
+ if clip_grad is not None:
256
+ assert parameters is not None
257
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
258
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
259
+ else:
260
+ self._scaler.unscale_(optimizer)
261
+ norm = get_grad_norm_(parameters)
262
+ self._scaler.step(optimizer)
263
+ self._scaler.update()
264
+ else:
265
+ norm = None
266
+ return norm
267
+
268
+ def state_dict(self):
269
+ return self._scaler.state_dict()
270
+
271
+ def load_state_dict(self, state_dict):
272
+ self._scaler.load_state_dict(state_dict)
273
+
274
+
275
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
276
+ if isinstance(parameters, torch.Tensor):
277
+ parameters = [parameters]
278
+ parameters = [p for p in parameters if p.grad is not None]
279
+ norm_type = float(norm_type)
280
+ if len(parameters) == 0:
281
+ return torch.tensor(0.)
282
+ device = parameters[0].grad.device
283
+ if norm_type == inf:
284
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
285
+ else:
286
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
287
+ return total_norm
288
+
289
+
290
+
291
+
292
+ def save_model(args, epoch, model_without_ddp, optimizer, loss_scaler, fname=None, best_so_far=None):
293
+ output_dir = Path(args.output_dir)
294
+ if fname is None: fname = str(epoch)
295
+ checkpoint_path = output_dir / ('checkpoint-%s.pth' % fname)
296
+ to_save = {
297
+ 'model': model_without_ddp.state_dict(),
298
+ 'optimizer': optimizer.state_dict(),
299
+ 'scaler': loss_scaler.state_dict(),
300
+ 'args': args,
301
+ 'epoch': epoch,
302
+ }
303
+ if best_so_far is not None: to_save['best_so_far'] = best_so_far
304
+ print(f'>> Saving model to {checkpoint_path} ...')
305
+ save_on_master(to_save, checkpoint_path)
306
+
307
+
308
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
309
+ args.start_epoch = 0
310
+ best_so_far = None
311
+ if args.resume is not None:
312
+ if args.resume.startswith('https'):
313
+ checkpoint = torch.hub.load_state_dict_from_url(
314
+ args.resume, map_location='cpu', check_hash=True)
315
+ else:
316
+ checkpoint = torch.load(args.resume, map_location='cpu')
317
+ print("Resume checkpoint %s" % args.resume)
318
+ model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
319
+ args.start_epoch = checkpoint['epoch'] + 1
320
+ optimizer.load_state_dict(checkpoint['optimizer'])
321
+ if 'scaler' in checkpoint:
322
+ loss_scaler.load_state_dict(checkpoint['scaler'])
323
+ if 'best_so_far' in checkpoint:
324
+ best_so_far = checkpoint['best_so_far']
325
+ print(" & best_so_far={:g}".format(best_so_far))
326
+ else:
327
+ print("")
328
+ print("With optim & sched! start_epoch={:d}".format(args.start_epoch), end='')
329
+ return best_so_far
330
+
331
+ def all_reduce_mean(x):
332
+ world_size = get_world_size()
333
+ if world_size > 1:
334
+ x_reduce = torch.tensor(x).cuda()
335
+ dist.all_reduce(x_reduce)
336
+ x_reduce /= world_size
337
+ return x_reduce.item()
338
+ else:
339
+ return x
340
+
341
+ def _replace(text, src, tgt, rm=''):
342
+ """ Advanced string replacement.
343
+ Given a text:
344
+ - replace all elements in src by the corresponding element in tgt
345
+ - remove all elements in rm
346
+ """
347
+ if len(tgt) == 1:
348
+ tgt = tgt * len(src)
349
+ assert len(src) == len(tgt), f"'{src}' and '{tgt}' should have the same len"
350
+ for s,t in zip(src, tgt):
351
+ text = text.replace(s,t)
352
+ for c in rm:
353
+ text = text.replace(c,'')
354
+ return text
355
+
356
+ def filename( obj ):
357
+ """ transform a python obj or cmd into a proper filename.
358
+ - \1 gets replaced by slash '/'
359
+ - \2 gets replaced by comma ','
360
+ """
361
+ if not isinstance(obj, str):
362
+ obj = repr(obj)
363
+ obj = str(obj).replace('()','')
364
+ obj = _replace(obj, '_,(*/\1\2','-__x%/,', rm=' )\'"')
365
+ assert all(len(s) < 256 for s in obj.split(os.sep)), 'filename too long (>256 characters):\n'+obj
366
+ return obj
367
+
368
+ def _get_num_layer_for_vit(var_name, enc_depth, dec_depth):
369
+ if var_name in ("cls_token", "mask_token", "pos_embed", "global_tokens"):
370
+ return 0
371
+ elif var_name.startswith("patch_embed"):
372
+ return 0
373
+ elif var_name.startswith("enc_blocks"):
374
+ layer_id = int(var_name.split('.')[1])
375
+ return layer_id + 1
376
+ elif var_name.startswith('decoder_embed') or var_name.startswith('enc_norm'): # part of the last block
377
+ return enc_depth
378
+ elif var_name.startswith('dec_blocks'):
379
+ layer_id = int(var_name.split('.')[1])
380
+ return enc_depth + layer_id + 1
381
+ elif var_name.startswith('dec_norm'): # part of the last block
382
+ return enc_depth + dec_depth
383
+ elif any(var_name.startswith(k) for k in ['head','prediction_head']):
384
+ return enc_depth + dec_depth + 1
385
+ else:
386
+ raise NotImplementedError(var_name)
387
+
388
+ def get_parameter_groups(model, weight_decay, layer_decay=1.0, skip_list=(), no_lr_scale_list=[]):
389
+ parameter_group_names = {}
390
+ parameter_group_vars = {}
391
+ enc_depth, dec_depth = None, None
392
+ # prepare layer decay values
393
+ assert layer_decay==1.0 or 0.<layer_decay<1.
394
+ if layer_decay<1.:
395
+ enc_depth = model.enc_depth
396
+ dec_depth = model.dec_depth if hasattr(model, 'dec_blocks') else 0
397
+ num_layers = enc_depth+dec_depth
398
+ layer_decay_values = list(layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2))
399
+
400
+ for name, param in model.named_parameters():
401
+ if not param.requires_grad:
402
+ continue # frozen weights
403
+
404
+ # Assign weight decay values
405
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list:
406
+ group_name = "no_decay"
407
+ this_weight_decay = 0.
408
+ else:
409
+ group_name = "decay"
410
+ this_weight_decay = weight_decay
411
+
412
+ # Assign layer ID for LR scaling
413
+ if layer_decay<1.:
414
+ skip_scale = False
415
+ layer_id = _get_num_layer_for_vit(name, enc_depth, dec_depth)
416
+ group_name = "layer_%d_%s" % (layer_id, group_name)
417
+ if name in no_lr_scale_list:
418
+ skip_scale = True
419
+ group_name = f'{group_name}_no_lr_scale'
420
+ else:
421
+ layer_id = 0
422
+ skip_scale = True
423
+
424
+ if group_name not in parameter_group_names:
425
+ if not skip_scale:
426
+ scale = layer_decay_values[layer_id]
427
+ else:
428
+ scale = 1.
429
+
430
+ parameter_group_names[group_name] = {
431
+ "weight_decay": this_weight_decay,
432
+ "params": [],
433
+ "lr_scale": scale
434
+ }
435
+ parameter_group_vars[group_name] = {
436
+ "weight_decay": this_weight_decay,
437
+ "params": [],
438
+ "lr_scale": scale
439
+ }
440
+
441
+ parameter_group_vars[group_name]["params"].append(param)
442
+ parameter_group_names[group_name]["params"].append(name)
443
+ print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
444
+ return list(parameter_group_vars.values())
445
+
446
+
447
+
448
+ def adjust_learning_rate(optimizer, epoch, args):
449
+ """Decay the learning rate with half-cycle cosine after warmup"""
450
+
451
+ if epoch < args.warmup_epochs:
452
+ lr = args.lr * epoch / args.warmup_epochs
453
+ else:
454
+ lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
455
+ (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
456
+
457
+ for param_group in optimizer.param_groups:
458
+ if "lr_scale" in param_group:
459
+ param_group["lr"] = lr * param_group["lr_scale"]
460
+ else:
461
+ param_group["lr"] = lr
462
+
463
+ return lr
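
To make the relationship between `get_parameter_groups`, the optimizer, and `adjust_learning_rate` concrete, here is a small self-contained sketch with toy values (the real training loop lives in `stereoflow/train.py` and `stereoflow/engine.py`). It assumes the croco root is on `PYTHONPATH`; the model and hyper-parameters are illustrative stand-ins only.

```python
# Toy sketch of the schedule and parameter-group helpers defined above.
# The model and hyper-parameters are stand-ins, not values used by the actual scripts.
import argparse
import torch
from utils.misc import get_parameter_groups, adjust_learning_rate

model = torch.nn.Linear(8, 2)   # stand-in for the real CroCo model
args = argparse.Namespace(lr=3e-5, min_lr=0.0, warmup_epochs=1, epochs=4)

param_groups = get_parameter_groups(model, weight_decay=0.05)  # bias/norm params get wd=0
optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.95))

for epoch in range(args.epochs):
    lr = adjust_learning_rate(optimizer, epoch, args)  # linear warmup, then half-cycle cosine
    print(f'epoch {epoch}: lr={lr:.2e}')
```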
third_party/dust3r/datasets_preprocess/habitat/README.md ADDED
@@ -0,0 +1,66 @@
1
+ ## Steps to reproduce synthetic training data using the Habitat-Sim simulator
2
+
3
+ ### Create a conda environment
4
+ ```bash
5
+ conda create -n habitat python=3.8 habitat-sim=0.2.1 headless=2.0 -c aihabitat -c conda-forge
6
+ conda activate habitat
7
+ conda install pytorch -c pytorch
8
+ pip install opencv-python tqdm
9
+ ```
10
+
11
+ or (if you get the error `For headless systems, compile with --headless for EGL support`)
12
+ ```
13
+ git clone --branch stable https://github.com/facebookresearch/habitat-sim.git
14
+ cd habitat-sim
15
+
16
+ conda create -n habitat python=3.9 cmake=3.14.0
17
+ conda activate habitat
18
+ pip install . -v
19
+ conda install pytorch -c pytorch
20
+ pip install opencv-python tqdm
21
+ ```
22
+
23
+ ### Download Habitat-Sim scenes
24
+ Download Habitat-Sim scenes:
25
+ - Download links can be found here: https://github.com/facebookresearch/habitat-sim/blob/main/DATASETS.md
26
+ - We used scenes from the HM3D, habitat-test-scenes, ReplicaCad and ScanNet datasets.
27
+ - Please put the scenes in a directory `$SCENES_DIR` following the structure below:
28
+ (Note: the habitat-sim dataset installer may install an incompatible version for ReplicaCAD baked lighting.
29
+ The correct scene dataset can be downloaded from Hugging Face: `git clone git@hf.co:datasets/ai-habitat/ReplicaCAD_baked_lighting`).
30
+ ```
31
+ $SCENES_DIR/
32
+ ├──hm3d/
33
+ ├──gibson/
34
+ ├──habitat-test-scenes/
35
+ ├──ReplicaCAD_baked_lighting/
36
+ └──scannet/
37
+ ```
38
+
39
+ ### Download renderings metadata
40
+
41
+ Download metadata corresponding to each scene and extract them into a directory `$METADATA_DIR`
42
+ ```bash
43
+ wget https://download.europe.naverlabs.com/ComputerVision/DUSt3R/habitat_5views_v1_512x512_metadata.tar.gz
44
+ tar -xvzf habitat_5views_v1_512x512_metadata.tar.gz
45
+ ```
46
+
47
+ ### Render the scenes
48
+
49
+ Render the scenes in an output directory `$OUTPUT_DIR`
50
+ ```bash
51
+ export METADATA_DIR="/path/to/habitat/5views_v1_512x512_metadata"
52
+ export SCENES_DIR="/path/to/habitat/data/scene_datasets/"
53
+ export OUTPUT_DIR="data/habitat_processed"
54
+ cd datasets_preprocess/habitat/
55
+ export PYTHONPATH=$(pwd)
56
+ # Print command lines to generate images corresponding to each scene
57
+ python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR
58
+ # Launch these command lines in parallel, e.g. using GNU Parallel as follows:
59
+ python preprocess_habitat.py --scenes_dir=$SCENES_DIR --metadata_dir=$METADATA_DIR --output_dir=$OUTPUT_DIR | parallel -j 16
60
+ ```
61
+
62
+ ### Make a list of scenes
63
+
64
+ ```bash
65
+ python find_scenes.py --root $OUTPUT_DIR
66
+ ```
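
As a final note (an optional sketch, not part of the original preprocessing pipeline), one can verify that the scene folders match the layout described above before launching the rendering step:

```python
# Optional sanity check: verify that the scene datasets listed in the README
# are present under $SCENES_DIR. The fallback path is a placeholder.
import os

scenes_dir = os.environ.get('SCENES_DIR', '/path/to/habitat/data/scene_datasets/')
expected = ['hm3d', 'gibson', 'habitat-test-scenes', 'ReplicaCAD_baked_lighting', 'scannet']
for name in expected:
    path = os.path.join(scenes_dir, name)
    print(f'{path}: {"found" if os.path.isdir(path) else "MISSING"}')
```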