Realcat committed on
Commit
10dcc2e
·
1 Parent(s): d21720c

add: COTR(https://github.com/ubc-vision/COTR)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +1 -0
  2. common/app_class.py +3 -1
  3. common/config.yaml +10 -0
  4. common/utils.py +33 -5
  5. env-docker.txt +2 -2
  6. hloc/match_dense.py +38 -0
  7. hloc/matchers/cotr.py +77 -0
  8. third_party/COTR/.gitignore +1 -0
  9. third_party/COTR/COTR/cameras/camera_pose.py +164 -0
  10. third_party/COTR/COTR/cameras/capture.py +432 -0
  11. third_party/COTR/COTR/cameras/pinhole_camera.py +73 -0
  12. third_party/COTR/COTR/datasets/colmap_helper.py +312 -0
  13. third_party/COTR/COTR/datasets/cotr_dataset.py +243 -0
  14. third_party/COTR/COTR/datasets/megadepth_dataset.py +140 -0
  15. third_party/COTR/COTR/global_configs/__init__.py +10 -0
  16. third_party/COTR/COTR/global_configs/commons.json +1 -0
  17. third_party/COTR/COTR/global_configs/dataset_config.json +41 -0
  18. third_party/COTR/COTR/inference/inference_helper.py +311 -0
  19. third_party/COTR/COTR/inference/refinement_task.py +191 -0
  20. third_party/COTR/COTR/inference/sparse_engine.py +427 -0
  21. third_party/COTR/COTR/models/__init__.py +10 -0
  22. third_party/COTR/COTR/models/backbone.py +135 -0
  23. third_party/COTR/COTR/models/cotr_model.py +51 -0
  24. third_party/COTR/COTR/models/misc.py +112 -0
  25. third_party/COTR/COTR/models/position_encoding.py +83 -0
  26. third_party/COTR/COTR/models/transformer.py +228 -0
  27. third_party/COTR/COTR/options/options.py +52 -0
  28. third_party/COTR/COTR/options/options_utils.py +108 -0
  29. third_party/COTR/COTR/projector/pcd_projector.py +210 -0
  30. third_party/COTR/COTR/sfm_scenes/knn_search.py +56 -0
  31. third_party/COTR/COTR/sfm_scenes/sfm_scenes.py +87 -0
  32. third_party/COTR/COTR/trainers/base_trainer.py +111 -0
  33. third_party/COTR/COTR/trainers/cotr_trainer.py +200 -0
  34. third_party/COTR/COTR/trainers/tensorboard_helper.py +97 -0
  35. third_party/COTR/COTR/transformations/transform_basics.py +114 -0
  36. third_party/COTR/COTR/transformations/transformations.py +1951 -0
  37. third_party/COTR/COTR/utils/constants.py +3 -0
  38. third_party/COTR/COTR/utils/debug_utils.py +15 -0
  39. third_party/COTR/COTR/utils/utils.py +271 -0
  40. third_party/COTR/LICENSE +201 -0
  41. third_party/COTR/demo_face.py +69 -0
  42. third_party/COTR/demo_guided_matching.py +85 -0
  43. third_party/COTR/demo_homography.py +84 -0
  44. third_party/COTR/demo_reconstruction.py +92 -0
  45. third_party/COTR/demo_single_pair.py +66 -0
  46. third_party/COTR/demo_wbs.py +71 -0
  47. third_party/COTR/environment.yml +104 -0
  48. third_party/COTR/out/.DS_Store +0 -0
  49. third_party/COTR/out/.placeholder +0 -0
  50. third_party/COTR/out/default/checkpoint.pth.tar +3 -0
README.md CHANGED
@@ -56,6 +56,7 @@ The tool currently supports various popular image matching algorithms, namely:
56
  - [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022
57
  - [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022
58
  - [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022
 
59
  - [x] [ALIKE](https://github.com/Shiaoming/ALIKE), TMM 2022
60
  - [x] [RoRD](https://github.com/UditSinghParihar/RoRD), IROS 2021
61
  - [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021
 
56
  - [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022
57
  - [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022
58
  - [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022
59
+ - [x] [CoTR](https://github.com/ubc-vision/COTR), ICCV 2021
60
  - [x] [ALIKE](https://github.com/Shiaoming/ALIKE), TMM 2022
61
  - [x] [RoRD](https://github.com/UditSinghParihar/RoRD), IROS 2021
62
  - [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021
common/app_class.py CHANGED
@@ -300,6 +300,7 @@ class ImageMatchingApp:
300
  fn=run_ransac,
301
  inputs=[
302
  state_cache,
 
303
  ransac_method,
304
  ransac_reproj_threshold,
305
  ransac_confidence,
@@ -308,6 +309,7 @@ class ImageMatchingApp:
308
  outputs=[
309
  output_matches_ransac,
310
  matches_result_info,
 
311
  ],
312
  )
313
 
@@ -457,7 +459,7 @@ class ImageMatchingApp:
457
  return gr.Markdown(markdown_table)
458
  elif style == "tab":
459
  for k, v in cfg.items():
460
- if not v["info"]["display"]:
461
  continue
462
  data.append(
463
  [
 
300
  fn=run_ransac,
301
  inputs=[
302
  state_cache,
303
+ choice_geometry_type,
304
  ransac_method,
305
  ransac_reproj_threshold,
306
  ransac_confidence,
 
309
  outputs=[
310
  output_matches_ransac,
311
  matches_result_info,
312
+ output_wrapped,
313
  ],
314
  )
315
 
 
459
  return gr.Markdown(markdown_table)
460
  elif style == "tab":
461
  for k, v in cfg.items():
462
+ if not v["info"].get("display", True):
463
  continue
464
  data.append(
465
  [
common/config.yaml CHANGED
@@ -46,6 +46,16 @@ matcher_zoo:
46
  paper: https://arxiv.org/pdf/2104.00680
47
  project: https://zju3dv.github.io/loftr
48
  display: true
 
 
 
 
 
 
 
 
 
 
49
  topicfm:
50
  matcher: topicfm
51
  dense: true
 
46
  paper: https://arxiv.org/pdf/2104.00680
47
  project: https://zju3dv.github.io/loftr
48
  display: true
49
+ cotr:
50
+ matcher: cotr
51
+ dense: true
52
+ info:
53
+ name: CoTR # display name
54
+ source: "ICCV 2021"
55
+ github: https://github.com/ubc-vision/COTR
56
+ paper: https://arxiv.org/abs/2103.14167
57
+ project: null
58
+ display: true
59
  topicfm:
60
  matcher: topicfm
61
  dense: true
common/utils.py CHANGED
@@ -443,6 +443,7 @@ def generate_warp_images(
443
 
444
  def run_ransac(
445
  state_cache: Dict[str, Any],
 
446
  ransac_method: str = DEFAULT_RANSAC_METHOD,
447
  ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
448
  ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
@@ -493,11 +494,32 @@ def run_ransac(
493
  )
494
  logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
495
  t1 = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
  num_matches_raw = state_cache["num_matches_raw"]
497
- return output_matches_ransac, {
498
- "num_matches_raw": num_matches_raw,
499
- "num_matches_ransac": num_matches_ransac,
500
- }
 
 
 
 
501
 
502
 
503
  def run_matching(
@@ -666,7 +688,13 @@ def run_matching(
666
 
667
  t1 = time.time()
668
  # plot wrapped images
669
- geom_info = compute_geometry(pred)
 
 
 
 
 
 
670
  output_wrapped, _ = generate_warp_images(
671
  pred["image0_orig"],
672
  pred["image1_orig"],
 
443
 
444
  def run_ransac(
445
  state_cache: Dict[str, Any],
446
+ choice_geometry_type: str,
447
  ransac_method: str = DEFAULT_RANSAC_METHOD,
448
  ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
449
  ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
 
494
  )
495
  logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
496
  t1 = time.time()
497
+
498
+ # compute warp images
499
+ geom_info = compute_geometry(
500
+ state_cache,
501
+ ransac_method=ransac_method,
502
+ ransac_reproj_threshold=ransac_reproj_threshold,
503
+ ransac_confidence=ransac_confidence,
504
+ ransac_max_iter=ransac_max_iter,
505
+ )
506
+ output_wrapped, _ = generate_warp_images(
507
+ state_cache["image0_orig"],
508
+ state_cache["image1_orig"],
509
+ {"geom_info": geom_info},
510
+ choice_geometry_type,
511
+ )
512
+ plt.close("all")
513
+
514
  num_matches_raw = state_cache["num_matches_raw"]
515
+ return (
516
+ output_matches_ransac,
517
+ {
518
+ "num_matches_raw": num_matches_raw,
519
+ "num_matches_ransac": num_matches_ransac,
520
+ },
521
+ output_wrapped,
522
+ )
523
 
524
 
525
  def run_matching(
 
688
 
689
  t1 = time.time()
690
  # plot wrapped images
691
+ geom_info = compute_geometry(
692
+ pred,
693
+ ransac_method=ransac_method,
694
+ ransac_reproj_threshold=ransac_reproj_threshold,
695
+ ransac_confidence=ransac_confidence,
696
+ ransac_max_iter=ransac_max_iter,
697
+ )
698
  output_wrapped, _ = generate_warp_images(
699
  pred["image0_orig"],
700
  pred["image1_orig"],
env-docker.txt CHANGED
@@ -1,8 +1,8 @@
1
  e2cnn==0.2.3
2
  einops==0.6.1
3
  gdown==4.7.1
4
- gradio==3.41.2
5
- gradio_client==0.5.0
6
  h5py==3.9.0
7
  imageio==2.31.1
8
  Jinja2==3.1.2
 
1
  e2cnn==0.2.3
2
  einops==0.6.1
3
  gdown==4.7.1
4
+ gradio==4.28.3
5
+ gradio_client==0.16.0
6
  h5py==3.9.0
7
  imageio==2.31.1
8
  Jinja2==3.1.2
hloc/match_dense.py CHANGED
@@ -28,6 +28,44 @@ confs = {
28
  "max_error": 1, # max error for assigned keypoints (in px)
29
  "cell_size": 1, # size of quantization patch (max 1 kp/patch)
30
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  # Semi-scalable loftr which limits detected keypoints
32
  "loftr_aachen": {
33
  "output": "matches-loftr_aachen",
 
28
  "max_error": 1, # max error for assigned keypoints (in px)
29
  "cell_size": 1, # size of quantization patch (max 1 kp/patch)
30
  },
31
+ # "loftr_quadtree": {
32
+ # "output": "matches-loftr-quadtree",
33
+ # "model": {
34
+ # "name": "quadtree",
35
+ # "weights": "outdoor",
36
+ # "max_keypoints": 2000,
37
+ # "match_threshold": 0.2,
38
+ # },
39
+ # "preprocessing": {
40
+ # "grayscale": True,
41
+ # "resize_max": 1024,
42
+ # "dfactor": 8,
43
+ # "width": 640,
44
+ # "height": 480,
45
+ # "force_resize": True,
46
+ # },
47
+ # "max_error": 1, # max error for assigned keypoints (in px)
48
+ # "cell_size": 1, # size of quantization patch (max 1 kp/patch)
49
+ # },
50
+ "cotr": {
51
+ "output": "matches-cotr",
52
+ "model": {
53
+ "name": "cotr",
54
+ "weights": "out/default",
55
+ "max_keypoints": 2000,
56
+ "match_threshold": 0.2,
57
+ },
58
+ "preprocessing": {
59
+ "grayscale": False,
60
+ "resize_max": 1024,
61
+ "dfactor": 8,
62
+ "width": 640,
63
+ "height": 480,
64
+ "force_resize": True,
65
+ },
66
+ "max_error": 1, # max error for assigned keypoints (in px)
67
+ "cell_size": 1, # size of quantization patch (max 1 kp/patch)
68
+ },
69
  # Semi-scalable loftr which limits detected keypoints
70
  "loftr_aachen": {
71
  "output": "matches-loftr_aachen",
hloc/matchers/cotr.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import torch
4
+ import warnings
5
+ import numpy as np
6
+ from pathlib import Path
7
+ from torchvision.transforms import ToPILImage
8
+ from ..utils.base_model import BaseModel
9
+
10
+ sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
11
+ from COTR.utils import utils as utils_cotr
12
+ from COTR.models import build_model
13
+ from COTR.options.options import *
14
+ from COTR.options.options_utils import *
15
+ from COTR.inference.inference_helper import triangulate_corr
16
+ from COTR.inference.sparse_engine import SparseEngine
17
+
18
+ utils_cotr.fix_randomness(0)
19
+ torch.set_grad_enabled(False)
20
+
21
+ cotr_path = Path(__file__).parent / "../../third_party/COTR"
22
+
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+
25
+
26
class COTR(BaseModel):
    """hloc matcher wrapper around the COTR correspondence model.

    ``default_conf`` keys:
      * ``weights``: checkpoint directory, relative to ``third_party/COTR``;
        ``checkpoint.pth.tar`` is loaded from it.
      * ``match_threshold``: kept in the config but not read by this class.
      * ``max_keypoints``: forwarded to the sparse engine as ``max_corrs``.
    """

    default_conf = {
        "weights": "out/default",
        "match_threshold": 0.2,
        "max_keypoints": -1,
    }
    required_inputs = ["image0", "image1"]

    def _init(self, conf):
        """Build the COTR network from its default options and load the checkpoint."""
        parser = argparse.ArgumentParser()
        set_COTR_arguments(parser)
        # Use parse_known_args(): this class is instantiated inside a host
        # process whose sys.argv may carry unrelated flags. A plain
        # parse_args() would abort with SystemExit on any flag COTR does not
        # define; known COTR options on the command line are still honoured.
        opt, _unknown_args = parser.parse_known_args()
        opt.command = " ".join(sys.argv)
        opt.load_weights_path = str(
            cotr_path / conf["weights"] / "checkpoint.pth.tar"
        )

        # Feed-forward width is tied to the backbone layer selected in opt.layer.
        layer_2_channels = {
            "layer1": 256,
            "layer2": 512,
            "layer3": 1024,
            "layer4": 2048,
        }
        opt.dim_feedforward = layer_2_channels[opt.layer]

        model = build_model(opt)
        model = model.to(device)
        # Load on CPU first; safe_load_weights copies into the model's device.
        weights = torch.load(opt.load_weights_path, map_location="cpu")[
            "model_state_dict"
        ]
        utils_cotr.safe_load_weights(model, weights)
        self.net = model.eval()
        self.to_pil_func = ToPILImage(mode="RGB")

    def _forward(self, data):
        """Match ``data["image0"]`` against ``data["image1"]``.

        Only the first element of each input is used (index ``[0]``); it is
        converted to an RGB PIL image and then to a numpy array.
        NOTE(review): presumably inputs are batched (B, 3, H, W) tensors - confirm.

        Returns:
            dict with ``keypoints0`` (first two columns of the correspondence
            array) and ``keypoints1`` (remaining columns) as torch tensors.
        """
        img_a = np.array(self.to_pil_func(data["image0"][0].cpu()))
        img_b = np.array(self.to_pil_func(data["image1"][0].cpu()))
        engine = SparseEngine(self.net, 32, mode="tile")
        corrs = engine.cotr_corr_multiscale_with_cycle_consistency(
            img_a,
            img_b,
            np.linspace(0.5, 0.0625, 4),
            1,
            max_corrs=self.conf["max_keypoints"],
            queries_a=None,
        )
        pred = {
            "keypoints0": torch.from_numpy(corrs[:, :2]),
            "keypoints1": torch.from_numpy(corrs[:, 2:]),
        }
        return pred
third_party/COTR/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ *.pyc
third_party/COTR/COTR/cameras/camera_pose.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Extrinsic camera pose
3
+ '''
4
+ import math
5
+ import copy
6
+
7
+ import numpy as np
8
+
9
+ from COTR.transformations import transformations
10
+ from COTR.transformations.transform_basics import Translation, Rotation, UnstableRotation
11
+
12
+
13
class CameraPose():
    '''
    Extrinsic camera pose held as a world-to-camera translation and rotation.
    '''

    def __init__(self, t: Translation, r: Rotation):
        '''
        WARN: World 2 cam
        Both the translation and the rotation map world to camera.
        translation_vector is NOT the camera position in world space;
        use camera_center_in_world for that.
        '''
        assert isinstance(t, Translation)
        assert isinstance(r, (Rotation, UnstableRotation))
        self.t = t
        self.r = r

    def __str__(self):
        return f'center in world: {self.camera_center_in_world}, translation(w2c): {self.t}, rotation(w2c): {self.r}'

    @classmethod
    def from_world_to_camera(cls, world_to_camera, unstable=False):
        '''
        Build a pose from a 4x4 world-to-camera matrix.
        With unstable=True the matrix is handed to UnstableRotation as-is,
        otherwise a float32 quaternion is extracted from it.
        '''
        assert isinstance(world_to_camera, np.ndarray)
        assert world_to_camera.shape == (4, 4)
        translation = Translation(
            transformations.translation_from_matrix(world_to_camera).astype(np.float32))
        if unstable:
            rotation = UnstableRotation(world_to_camera)
        else:
            rotation = Rotation(
                transformations.quaternion_from_matrix(world_to_camera).astype(np.float32))
        return cls(translation, rotation)

    @classmethod
    def from_camera_to_world(cls, camera_to_world, unstable=False):
        '''
        Build a pose from a 4x4 camera-to-world matrix by inverting it.
        '''
        assert isinstance(camera_to_world, np.ndarray)
        assert camera_to_world.shape == (4, 4)
        w2c = np.linalg.inv(camera_to_world)
        # normalise the homogeneous scale
        w2c /= w2c[3, 3]
        return cls.from_world_to_camera(w2c, unstable)

    @classmethod
    def from_pose_vector(cls, pose_vector):
        '''
        Build a pose from a vector: first 3 entries translation, rest rotation.
        '''
        return cls(Translation(pose_vector[:3]), Rotation(pose_vector[3:]))

    @property
    def translation_vector(self):
        return self.t.translation_vector

    @property
    def translation_matrix(self):
        return self.t.translation_matrix

    @property
    def quaternion(self):
        '''
        quaternion format (w, x, y, z)
        '''
        return self.r.quaternion

    @property
    def rotation_matrix(self):
        return self.r.rotation_matrix

    @property
    def pose_vector(self):
        '''
        Translation vector followed by the quaternion:
        (X, Y, Z, w, x, y, z), world to camera.
        '''
        parts = [self.translation_vector, self.quaternion]
        return np.concatenate(parts)

    @property
    def inv_pose_vector(self):
        '''
        Camera center in world followed by the inverse quaternion.
        '''
        inverse_quaternion = transformations.quaternion_inverse(self.quaternion)
        return np.concatenate([self.camera_center_in_world, inverse_quaternion])

    @property
    def pose_vector_6_dof(self):
        '''
        Assumes the quaternion is normalized and drops its W component:
        (X, Y, Z, x, y, z)
        '''
        return np.concatenate([self.translation_vector, self.quaternion[1:]])

    @property
    def world_to_camera(self):
        '''
        4x4 matrix: translation matrix times rotation matrix, scaled so [3, 3] == 1.
        '''
        mat = np.matmul(self.translation_matrix, self.rotation_matrix)
        mat /= mat[3, 3]
        return mat

    @property
    def world_to_camera_3x4(self):
        return self.world_to_camera[:3, :4]

    @property
    def extrinsic_mat(self):
        return self.world_to_camera_3x4

    @property
    def camera_to_world(self):
        '''
        Inverse of world_to_camera, scaled so [3, 3] == 1.
        '''
        mat = np.linalg.inv(self.world_to_camera)
        mat /= mat[3, 3]
        return mat

    @property
    def camera_to_world_3x4(self):
        return self.camera_to_world[:3, :4]

    @property
    def camera_center_in_world(self):
        # last column of camera-to-world
        return self.camera_to_world[:3, 3]

    @property
    def forward(self):
        # third column of camera-to-world
        return self.camera_to_world[:3, 2]

    @property
    def up(self):
        # second column of camera-to-world
        return self.camera_to_world[:3, 1]

    @property
    def right(self):
        # first column of camera-to-world
        return self.camera_to_world[:3, 0]

    @property
    def essential_matrix(self):
        '''
        Cross product of the 3x3 rotation with the camera center, scaled to unit norm.
        '''
        unnormalized = np.cross(self.rotation_matrix[:3, :3], self.camera_center_in_world)
        return unnormalized / np.linalg.norm(unnormalized)
145
+
146
+
147
def inverse_camera_pose(cam_pose: CameraPose):
    '''
    Return a new CameraPose whose world-to-camera matrix is the inverse of cam_pose's.
    '''
    inverted = np.linalg.inv(cam_pose.world_to_camera)
    return CameraPose.from_world_to_camera(inverted)
149
+
150
+
151
def rotate_camera_pose(cam_pose, rot):
    '''
    Left-multiply the pose's world-to-camera matrix by a rotation about the z axis.

    rot is in degrees; rot == 0 returns a deep copy of cam_pose.
    '''
    if rot == 0:
        return copy.deepcopy(cam_pose)

    angle = rot / 180 * np.pi
    sin_rot = np.sin(angle)
    cos_rot = np.cos(angle)

    # Each vector below becomes one COLUMN of the 4x4 matrix (stacked on axis 1).
    columns = [
        np.stack([cos_rot, -sin_rot, 0, 0], axis=-1),
        np.stack([sin_rot, cos_rot, 0, 0], axis=-1),
        np.stack([0, 0, 1, 0], axis=-1),
        np.stack([0, 0, 0, 1], axis=-1),
    ]
    rot_mat = np.stack(columns, axis=1)
    return CameraPose.from_world_to_camera(np.matmul(rot_mat, cam_pose.world_to_camera))
third_party/COTR/COTR/cameras/capture.py ADDED
@@ -0,0 +1,432 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Capture from a pinhole camera
3
+ Separate the captured content and the camera...
4
+ '''
5
+
6
+ import os
7
+ import time
8
+ import abc
9
+ import copy
10
+
11
+ import cv2
12
+ import torch
13
+ import numpy as np
14
+ import imageio
15
+ import PIL
16
+ from PIL import Image
17
+
18
+ from COTR.cameras.camera_pose import CameraPose, rotate_camera_pose
19
+ from COTR.cameras.pinhole_camera import PinholeCamera, rotate_pinhole_camera, crop_pinhole_camera
20
+ from COTR.utils import debug_utils, utils, constants
21
+ from COTR.utils.utils import Point2D
22
+ from COTR.projector import pcd_projector
23
+ from COTR.utils.constants import MAX_SIZE
24
+ from COTR.utils.utils import CropCamConfig
25
+
26
+
27
def crop_center_max_xy(p2d, shape):
    '''
    Keep the 2D points that fall strictly inside the largest centered square
    of an image of the given (h, w) shape, shifted into the square's frame.
    '''
    h, w = shape
    side = min(h, w)
    left = w // 2 - side // 2
    top = h // 2 - side // 2
    xs = p2d.xy[:, 0]
    ys = p2d.xy[:, 1]
    inside = (xs > left) & (xs < left + side) & (ys > top) & (ys < top + side)
    shifted = (p2d.xy - [left, top])[inside]
    return Point2D(p2d.id_3d[inside], shifted)
37
+
38
+
39
def crop_center_max(img):
    '''
    Center-crop img to its largest square, dispatching on container type.
    Raises ValueError for anything that is neither a torch.Tensor nor an np.ndarray.
    '''
    if isinstance(img, torch.Tensor):
        return crop_center_max_torch(img)
    if isinstance(img, np.ndarray):
        return crop_center_max_np(img)
    raise ValueError
46
+
47
+
48
def crop_center_max_torch(img):
    '''
    Center-crop the last two axes (h, w) of a 2D, 3D or 4D tensor to a square
    whose side is min(h, w). Raises ValueError for any other rank.
    '''
    if len(img.shape) not in (2, 3, 4):
        raise ValueError
    h, w = img.shape[-2:]
    side = min(h, w)
    left = w // 2 - side // 2
    top = h // 2 - side // 2
    # the ellipsis covers the optional leading batch/channel axes
    return img[..., top:top + side, left:left + side]
65
+
66
+
67
def crop_center_max_np(img, return_starts=False):
    '''
    Center-crop an array to a square whose side is min(h, w).

    Accepted layouts: (h, w), (h, w, c) and (b, h, w, c); any other rank
    raises ValueError. With return_starts=True the result is
    (cropped, -start_x, -start_y), i.e. the negated crop offsets.
    '''
    rank = len(img.shape)
    if rank in (2, 3):
        h, w = img.shape[:2]
    elif rank == 4:
        h, w = img.shape[1:3]
    else:
        raise ValueError
    side = min(h, w)
    left = w // 2 - side // 2
    top = h // 2 - side // 2
    rows = slice(top, top + side)
    cols = slice(left, left + side)
    if rank == 2:
        canvas = img[rows, cols]
    elif rank == 3:
        canvas = img[rows, cols, :]
    else:
        canvas = img[:, rows, cols, :]
    if return_starts:
        return canvas, -left, -top
    return canvas
90
+
91
+
92
def pad_to_square_np(img, till_divisible_by=1, return_starts=False):
    '''
    Zero-pad an array into a square canvas, placing the input around the center.

    Accepted layouts: (h, w), (h, w, c) and (b, h, w, c); any other rank
    raises ValueError. With till_divisible_by == 1 the side is max(h, w);
    otherwise the side is pushed to the next multiple of till_divisible_by.
    NOTE(review): when max(h, w) is already a multiple, a full extra
    till_divisible_by is still added - confirm this is intended.
    With return_starts=True the result is (canvas, start_x, start_y).
    '''
    rank = len(img.shape)
    if rank in (2, 3):
        row_axis = 0
    elif rank == 4:
        row_axis = 1
    else:
        raise ValueError
    col_axis = row_axis + 1
    h, w = img.shape[row_axis], img.shape[col_axis]

    longest = max(h, w)
    if till_divisible_by == 1:
        size = longest
    else:
        size = (longest + till_divisible_by) - (longest % till_divisible_by)
    left = size // 2 - w // 2
    top = size // 2 - h // 2

    # Same shape as the input, except both spatial axes become `size`.
    canvas_shape = list(img.shape)
    canvas_shape[row_axis] = size
    canvas_shape[col_axis] = size
    canvas = np.zeros(canvas_shape, dtype=img.dtype)

    region = [slice(None)] * rank
    region[row_axis] = slice(top, top + h)
    region[col_axis] = slice(left, left + w)
    canvas[tuple(region)] = img

    if return_starts:
        return canvas, left, top
    return canvas
120
+
121
+
122
def stretch_to_square_np(img):
    '''
    Resize img (bilinear, via PIL) to a square whose side is the larger of
    its first two dimensions; the aspect ratio is not preserved.
    '''
    side = max(*img.shape[:2])
    stretched = PIL.Image.fromarray(img).resize((side, side), resample=PIL.Image.BILINEAR)
    return np.array(stretched)
125
+
126
+
127
def rotate_image(image, angle, interpolation=cv2.INTER_LINEAR):
    '''
    Rotate image by angle around its center with cv2.warpAffine, keeping
    the original width and height.
    '''
    size_wh = image.shape[1::-1]  # (width, height)
    center = tuple(np.array(size_wh) / 2)
    affine = cv2.getRotationMatrix2D(center, angle, 1.0)
    return cv2.warpAffine(image, affine, size_wh, flags=interpolation)
132
+
133
+
134
def read_array(path):
    '''
    Read a dense array file made of a "width&height&channels&" text header
    followed by raw float32 data.
    https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py

    Returns the data transposed to (height, width, channels) with
    singleton axes squeezed out.
    Raises ValueError when the file ends before the third "&" delimiter
    (previously this case looped forever).
    '''
    with open(path, "rb") as fid:
        width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
                                                usecols=(0, 1, 2), dtype=int)
        # genfromtxt consumed an unknown amount; rewind and skip exactly the
        # three-delimiter header so the payload starts at the right byte.
        fid.seek(0)
        num_delimiter = 0
        while num_delimiter < 3:
            byte = fid.read(1)
            if not byte:
                # read() returns b"" at EOF; without this check the scan never ends
                raise ValueError('truncated header in array file: {0}'.format(path))
            if byte == b"&":
                num_delimiter += 1
        array = np.fromfile(fid, np.float32)
    # the payload is stored column-major as (width, height, channels)
    array = array.reshape((width, height, channels), order="F")
    return np.transpose(array, (1, 0, 2)).squeeze()
153
+
154
+
155
+ ################ Content ################
156
+
157
+
158
class CapturedContent(abc.ABC):
    '''
    Base class for captured content; it only tracks a rotation value,
    which starts at 0.
    '''

    def __init__(self):
        self._rotation = 0

    @property
    def rotation(self):
        '''Rotation currently assigned to this content.'''
        return self._rotation

    @rotation.setter
    def rotation(self, value):
        self._rotation = value
169
+
170
+
171
class CapturedImage(CapturedContent):
    '''
    RGB image on disk plus the crop/rotation applied when it is loaded.

    crop_cam is one of 'no_crop', 'crop_center', 'crop_center_and_resize'
    or a CropCamConfig instance.
    '''

    def __init__(self, img_path, crop_cam, pinhole_cam_before=None):
        super().__init__()
        assert os.path.isfile(img_path), f'file does not exist: {img_path}'
        self.img_path = img_path
        self.crop_cam = crop_cam
        self.pinhole_cam_before = pinhole_cam_before
        self._image = None
        self._p2d = None

    def read_image_to_ram(self) -> int:
        '''
        Load the processed image once, cache it, and return its size in bytes.
        Asserts that nothing has been cached yet.
        '''
        # raise NotImplementedError
        assert self._image is None
        self._image = self.image
        return self._image.nbytes

    @property
    def image(self):
        '''
        Cached image if present; otherwise the image read from disk, then
        rotated, resized to pinhole_cam_before.shape and cropped per crop_cam.
        '''
        if self._image is not None:
            return self._image

        frame = imageio.imread(self.img_path, pilmode='RGB')
        if self.rotation != 0:
            frame = rotate_image(frame, self.rotation)

        full_shape = self.pinhole_cam_before.shape  # (height, width)
        if frame.shape[:2] != full_shape:
            # PIL takes (width, height), hence the reversed shape
            frame = np.array(PIL.Image.fromarray(frame).resize(full_shape[::-1], resample=PIL.Image.BILINEAR))
        assert frame.shape[:2] == full_shape

        crop = self.crop_cam
        if crop == 'no_crop':
            return frame
        if crop == 'crop_center':
            return crop_center_max(frame)
        if crop == 'crop_center_and_resize':
            frame = crop_center_max(frame)
            return np.array(PIL.Image.fromarray(frame).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
        if isinstance(crop, CropCamConfig):
            assert frame.shape[0] == crop.orig_h
            assert frame.shape[1] == crop.orig_w
            frame = frame[crop.y:crop.y + crop.h, crop.x:crop.x + crop.w, ]
            frame = np.array(PIL.Image.fromarray(frame).resize((crop.out_w, crop.out_h), resample=PIL.Image.BILINEAR))
            assert frame.shape[:2] == (crop.out_h, crop.out_w)
            return frame
        raise ValueError()

    @property
    def p2d(self):
        '''
        Stored 2D points adjusted for the crop, or None when none were set.
        Only 'no_crop' and 'crop_center' are supported; other crop modes
        raise ValueError.
        '''
        points = self._p2d
        if points is None:
            return None
        if self.crop_cam == 'no_crop':
            return points
        if self.crop_cam == 'crop_center':
            return crop_center_max_xy(points, self.pinhole_cam_before.shape)
        raise ValueError()

    @p2d.setter
    def p2d(self, value):
        assert value is None or isinstance(value, Point2D)
        self._p2d = value
236
+
237
+
238
class CapturedDepth(CapturedContent):
    '''
    Depth map on disk plus the crop/rotation applied when it is loaded.

    depth_path is either an existing '.h5' file, or an image path with the
    suffix 'dummy' appended, in which case an all-zero depth map of that
    image's size is produced. crop_cam is one of 'no_crop', 'crop_center',
    'crop_center_and_resize' or a CropCamConfig instance.
    '''

    def __init__(self, depth_path, crop_cam, pinhole_cam_before=None):
        super(CapturedDepth, self).__init__()
        # 'dummy' paths do not exist on disk, so only real paths are checked
        if not depth_path.endswith('dummy'):
            assert os.path.isfile(depth_path), 'file does not exist: {0}'.format(depth_path)
        self.crop_cam = crop_cam
        self._depth = None
        self.depth_path = depth_path
        self.pinhole_cam_before = pinhole_cam_before

    def read_depth(self):
        '''
        Read the raw depth map as float32.
        Raises ValueError when depth_path ends with neither 'dummy' nor '.h5'.
        '''
        import tables
        if self.depth_path.endswith('dummy'):
            # strip the 5-character 'dummy' suffix to get the image path
            image_path = self.depth_path[:-5]
            w, h = Image.open(image_path).size
            _depth = np.zeros([h, w], dtype=np.float32)
        elif self.depth_path.endswith('.h5'):
            depth_h5 = tables.open_file(self.depth_path, mode='r')
            try:
                _depth = np.array(depth_h5.root.depth)
            finally:
                # close the HDF5 handle even if reading the node raises
                depth_h5.close()
        else:
            raise ValueError
        return _depth.astype(np.float32)

    def read_depth_to_ram(self) -> int:
        '''
        Load the processed depth map once, cache it, and return its size in bytes.
        Asserts that nothing has been cached yet.
        '''
        # raise NotImplementedError
        assert self._depth is None
        _depth = self.depth_map
        self._depth = _depth
        return self._depth.nbytes

    @property
    def depth_map(self):
        '''
        Cached depth if present; otherwise the depth read from disk, then
        rotated, resized to pinhole_cam_before.shape and cropped per crop_cam.
        Nearest-neighbour sampling is used throughout. Asserts that all
        depth values are non-negative.
        '''
        if self._depth is not None:
            _depth = self._depth
        else:
            _depth = self.read_depth()
            if self.rotation != 0:
                _depth = rotate_image(_depth, self.rotation, interpolation=cv2.INTER_NEAREST)
            if _depth.shape != self.pinhole_cam_before.shape:
                # PIL takes (width, height), hence the reversed shape
                _depth = np.array(PIL.Image.fromarray(_depth).resize(self.pinhole_cam_before.shape[::-1], resample=PIL.Image.NEAREST))
            assert _depth.shape[:2] == self.pinhole_cam_before.shape
            if self.crop_cam == 'no_crop':
                pass
            elif self.crop_cam == 'crop_center':
                _depth = crop_center_max(_depth)
            elif self.crop_cam == 'crop_center_and_resize':
                _depth = crop_center_max(_depth)
                _depth = np.array(PIL.Image.fromarray(_depth).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.NEAREST))
            elif isinstance(self.crop_cam, CropCamConfig):
                assert _depth.shape[0] == self.crop_cam.orig_h
                assert _depth.shape[1] == self.crop_cam.orig_w
                _depth = _depth[self.crop_cam.y:self.crop_cam.y + self.crop_cam.h,
                                self.crop_cam.x:self.crop_cam.x + self.crop_cam.w, ]
                _depth = np.array(PIL.Image.fromarray(_depth).resize((self.crop_cam.out_w, self.crop_cam.out_h), resample=PIL.Image.NEAREST))
                assert _depth.shape[:2] == (self.crop_cam.out_h, self.crop_cam.out_w)
            else:
                raise ValueError()
        assert (_depth >= 0).all()
        return _depth
298
+
299
+
300
+ ################ Pinhole Capture ################
301
class BasePinholeCapture():
    '''
    A pinhole camera, its pose, and the crop applied to the capture.

    pinhole_cam holds the intrinsics after cropping; pinhole_cam_before
    keeps the camera that was passed in.
    '''

    def __init__(self, pinhole_cam, cam_pose, crop_cam):
        self.crop_cam = crop_cam
        self.cam_pose = cam_pose
        self.pinhole_cam_before = pinhole_cam
        # modify the camera intrinsics to account for the crop
        self.pinhole_cam = crop_pinhole_camera(pinhole_cam, crop_cam)

    def __str__(self):
        return f'pinhole camera: {self.pinhole_cam}\ncamera pose: {self.cam_pose}'

    @property
    def intrinsic_mat(self):
        return self.pinhole_cam.intrinsic_mat

    @property
    def extrinsic_mat(self):
        return self.cam_pose.extrinsic_mat

    @property
    def shape(self):
        return self.pinhole_cam.shape

    @property
    def size(self):
        # alias of shape
        return self.shape

    @property
    def mvp_mat(self):
        '''
        model-view-projection matrix (naming from opengl):
        intrinsics times the 3x4 world-to-camera matrix
        '''
        intrinsics = self.pinhole_cam.intrinsic_mat
        return np.matmul(intrinsics, self.cam_pose.world_to_camera_3x4)
335
+
336
+
337
class RGBPinholeCapture(BasePinholeCapture):
    '''
    Pinhole capture carrying an RGB image (wrapped in a CapturedImage).
    '''

    def __init__(self, img_path, pinhole_cam, cam_pose, crop_cam):
        # explicit base call: RGBDPinholeCapture invokes both parents' __init__ by name
        BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam)
        self.captured_image = CapturedImage(img_path, crop_cam, self.pinhole_cam_before)

    def read_image_to_ram(self) -> int:
        '''Cache the image in memory; returns its size in bytes.'''
        return self.captured_image.read_image_to_ram()

    @property
    def img_path(self):
        return self.captured_image.img_path

    @property
    def image(self):
        '''Processed image; asserts its (h, w) equals the cropped camera's shape.'''
        loaded = self.captured_image.image
        assert loaded.shape[0:2] == self.pinhole_cam.shape, f'image shape: {loaded.shape}, pinhole camera: {self.pinhole_cam}'
        return loaded

    @property
    def seq_id(self):
        # directory that contains the image file
        return os.path.dirname(self.captured_image.img_path)

    @property
    def p2d(self):
        return self.captured_image.p2d

    @p2d.setter
    def p2d(self, value):
        self.captured_image.p2d = value
366
+
367
+
368
class DepthPinholeCapture(BasePinholeCapture):
    '''
    Pinhole capture carrying a depth map (wrapped in a CapturedDepth).
    '''

    def __init__(self, depth_path, pinhole_cam, cam_pose, crop_cam):
        # explicit base call: RGBDPinholeCapture invokes both parents' __init__ by name
        BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam)
        self.captured_depth = CapturedDepth(depth_path, crop_cam, self.pinhole_cam_before)

    def read_depth_to_ram(self) -> int:
        '''Cache the depth map in memory; returns its size in bytes.'''
        return self.captured_depth.read_depth_to_ram()

    @property
    def depth_path(self):
        return self.captured_depth.depth_path

    @property
    def depth_map(self):
        '''Processed depth map; asserts all values are non-negative.'''
        depth = self.captured_depth.depth_map
        # if self.pinhole_cam.shape != _depth.shape:
        #     _depth = misc.imresize(_depth, self.pinhole_cam.shape, interp='nearest', mode='F')
        assert (depth >= 0).all()
        return depth

    @property
    def point_cloud_world(self):
        return self.get_point_cloud_world_from_depth(feat_map=None)

    def get_point_cloud_world_from_depth(self, feat_map=None):
        '''
        Back-project the depth map through the intrinsics and the
        camera-to-world motion; feat_map is forwarded as the projector's
        `img` argument. The result is cast to constants.DEFAULT_PRECISION.
        '''
        depth = self.depth_map
        intrinsics = self.pinhole_cam.intrinsic_mat
        motion = self.cam_pose.camera_to_world
        cloud = pcd_projector.PointCloudProjector.img_2d_to_pcd_3d_np(depth, intrinsics, img=feat_map, motion=motion)
        return cloud.astype(constants.DEFAULT_PRECISION)
395
+
396
+
397
class RGBDPinholeCapture(RGBPinholeCapture, DepthPinholeCapture):
    '''
    Pinhole capture carrying both an RGB image and a depth map.
    '''

    def __init__(self, img_path, depth_path, pinhole_cam, cam_pose, crop_cam):
        # each parent initialiser is called by name with its own path
        RGBPinholeCapture.__init__(self, img_path, pinhole_cam, cam_pose, crop_cam)
        DepthPinholeCapture.__init__(self, depth_path, pinhole_cam, cam_pose, crop_cam)

    @property
    def point_cloud_w_rgb_world(self):
        '''World point cloud with the RGB image passed as the feature map.'''
        rgb = self.image
        return self.get_point_cloud_world_from_depth(feat_map=rgb)
405
+
406
+
407
def rotate_capture(cap, rot):
    '''
    Return a deep copy of cap rotated by rot.

    For rot != 0 the copy gets a rotated camera pose, and rot is recorded
    on its captured image and/or depth when those attributes exist.
    '''
    if rot == 0:
        return copy.deepcopy(cap)

    rotated_pose = rotate_camera_pose(cap.cam_pose, rot)
    rotated = copy.deepcopy(cap)
    rotated.cam_pose = rotated_pose
    for content_name in ('captured_image', 'captured_depth'):
        if hasattr(rotated, content_name):
            getattr(rotated, content_name).rotation = rot
    return rotated
419
+
420
+
421
def crop_capture(cap, crop_cam):
    '''
    Rebuild cap with a different crop_cam, carrying over the content rotations.

    Supports RGBDPinholeCapture and RGBPinholeCapture; anything else
    raises ValueError.
    '''
    # RGBD is tested first because it is also an RGBPinholeCapture
    if isinstance(cap, RGBDPinholeCapture):
        result = RGBDPinholeCapture(cap.img_path, cap.depth_path, cap.pinhole_cam, cap.cam_pose, crop_cam)
    elif isinstance(cap, RGBPinholeCapture):
        result = RGBPinholeCapture(cap.img_path, cap.pinhole_cam, cap.cam_pose, crop_cam)
    else:
        raise ValueError
    for content_name in ('captured_image', 'captured_depth'):
        if hasattr(result, content_name):
            getattr(result, content_name).rotation = getattr(cap, content_name).rotation
    return result
third_party/COTR/COTR/cameras/pinhole_camera.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Static pinhole camera
3
+ """
4
+
5
+ import copy
6
+
7
+ import numpy as np
8
+
9
+ from COTR.utils import constants
10
+ from COTR.utils.constants import MAX_SIZE
11
+ from COTR.utils.utils import CropCamConfig
12
+
13
+
14
class PinholeCamera():
    """Static pinhole camera: image size, focal lengths and principal point."""

    def __init__(self, width, height, fx, fy, cx, cy):
        # Image dimensions are coerced to int; intrinsics are stored as given.
        self.width, self.height = int(width), int(height)
        self.fx, self.fy = fx, fy
        self.cx, self.cy = cx, cy

    def __str__(self):
        return (f'width: {self.width}, height: {self.height}, '
                f'fx: {self.fx}, fy: {self.fy}, cx: {self.cx}, cy: {self.cy}')

    @property
    def shape(self):
        """Image size as ``(height, width)``."""
        return (self.height, self.width)

    @property
    def intrinsic_mat(self):
        """3x3 intrinsic matrix ``[[fx, 0, cx], [0, fy, cy], [0, 0, 1]]``."""
        rows = [
            [self.fx, 0.0, self.cx],
            [0.0, self.fy, self.cy],
            [0.0, 0.0, 1.0],
        ]
        return np.array(rows, dtype=constants.DEFAULT_PRECISION)
37
+
38
+
39
def rotate_pinhole_camera(cam, rot):
    """Deliberately disabled helper: the first assertion always fails.

    The remaining code (only reachable when assertions are turned off)
    returns a copy of ``cam`` for 0/180 degrees and a camera with the x and y
    parameters swapped for 90/270 degrees.
    """
    assert 0, 'TODO: Camera should stay the same while rotation'
    assert rot in [0, 90, 180, 270], 'only support 0/90/180/270 degrees rotation'
    if rot in (0, 180):
        return copy.deepcopy(cam)
    if rot in (90, 270):
        # Quarter turns exchange width/height, fx/fy and cx/cy.
        return PinholeCamera(width=cam.height, height=cam.width, fx=cam.fy, fy=cam.fx, cx=cam.cy, cy=cam.cx)
    raise NotImplementedError
48
+
49
+
50
def crop_pinhole_camera(pinhole_cam, crop_cam):
    """Return the pinhole camera that corresponds to cropping ``pinhole_cam``.

    ``crop_cam`` is one of:
      * ``'no_crop'``: the input camera is returned as-is (same object);
      * ``'crop_center'``: square camera whose side is the shorter image side;
      * ``'crop_center_and_resize'``: as above, then rescaled to ``MAX_SIZE``;
      * a ``CropCamConfig``: arbitrary crop window resized to ``out_w`` x ``out_h``.

    Any other value raises ``ValueError``.
    """
    if crop_cam == 'no_crop':
        return pinhole_cam

    if crop_cam == 'crop_center':
        side = min(*pinhole_cam.shape)
        centre = side / 2
        return PinholeCamera(side, side, pinhole_cam.fx, pinhole_cam.fy, centre, centre)

    if crop_cam == 'crop_center_and_resize':
        centre = MAX_SIZE / 2
        scale = MAX_SIZE / min(*pinhole_cam.shape)
        return PinholeCamera(MAX_SIZE, MAX_SIZE, pinhole_cam.fx * scale, pinhole_cam.fy * scale, centre, centre)

    if isinstance(crop_cam, CropCamConfig):
        # A single scale, taken from the height ratio, is applied to both axes.
        scale = crop_cam.out_h / crop_cam.h
        return PinholeCamera(
            crop_cam.out_w,
            crop_cam.out_h,
            pinhole_cam.fx * scale,
            pinhole_cam.fy * scale,
            (pinhole_cam.cx - crop_cam.x) * scale,
            (pinhole_cam.cy - crop_cam.y) * scale,
        )

    raise ValueError
third_party/COTR/COTR/datasets/colmap_helper.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ assert sys.version_info >= (3, 7), 'ordered dict is required'
3
+ import os
4
+ import re
5
+ from collections import namedtuple
6
+ import json
7
+
8
+ import numpy as np
9
+ from tqdm import tqdm
10
+
11
+ from COTR.utils import debug_utils
12
+ from COTR.cameras.pinhole_camera import PinholeCamera
13
+ from COTR.cameras.camera_pose import CameraPose
14
+ from COTR.cameras.capture import RGBPinholeCapture, RGBDPinholeCapture
15
+ from COTR.cameras import capture
16
+ from COTR.transformations import transformations
17
+ from COTR.transformations.transform_basics import Translation, Rotation
18
+ from COTR.sfm_scenes import sfm_scenes
19
+ from COTR.global_configs import dataset_config
20
+ from COTR.utils.utils import Point2D, Point3D
21
+
22
# One parsed record of images.txt: image id, rotation ``r`` and translation ``t``,
# the referenced camera id, the resolved image path, and the optional
# covisibility data (``point3d_id`` / ``p2d`` are None unless COVISIBILITY_CHECK is set).
ImageMeta = namedtuple('ImageMeta', ['image_id', 'r', 't', 'camera_id', 'image_path', 'point3d_id', 'p2d'])
# When True, the readers also collect per-point image ids and per-image 2D observations.
COVISIBILITY_CHECK = False
# When True, read_sfm_scene also parses points3D.txt; otherwise the scene gets no point cloud.
LOAD_PCD = False
25
+
26
+
27
class ColmapAsciiReader():
    '''
    Reader for COLMAP models stored as text files
    (``cameras.txt``, ``images.txt``, ``points3D.txt``).

    Malformed headers are reported through ``AssertionError``; camera models
    other than ``PINHOLE`` raise ``ValueError``.
    '''

    def __init__(self):
        pass

    @classmethod
    def read_sfm_scene(cls, scene_dir, images_dir, crop_cam):
        '''
        Build an ``SfmScene`` from the COLMAP text files found in ``scene_dir``.

        The point cloud is only parsed when the module flag ``LOAD_PCD`` is
        set; otherwise the scene is created with ``None`` as its point cloud.
        '''
        point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
        cameras_path = os.path.join(scene_dir, 'cameras.txt')
        images_path = os.path.join(scene_dir, 'images.txt')
        captures = cls.read_captures(images_path, cameras_path, images_dir, crop_cam)
        point_cloud = cls.read_point_cloud(point_cloud_path) if LOAD_PCD else None
        return sfm_scenes.SfmScene(captures, point_cloud)

    @staticmethod
    def read_point_cloud(points_txt_path):
        '''
        Parse ``points3D.txt``.

        Returns an ``(N, 6)`` float32 array with columns x, y, z, r, g, b.
        When ``COVISIBILITY_CHECK`` is set, a dict mapping each point id to a
        ``Point3D`` record is returned as a second value.
        '''
        with open(points_txt_path, "r") as fid:
            line = fid.readline()
            assert line == '# 3D point list with one line of data per point:\n'
            line = fid.readline()
            assert line == '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n'
            line = fid.readline()
            # The alternation is grouped so that the whole line is validated;
            # ungrouped, "a|b" accepted any line that merely ended in digits.
            assert re.search(r'^# Number of points: \d+, mean track length: (?:[-+]?\d*\.\d+|\d+)\n$', line)
            # Unpacking enforces that exactly two numbers appear in the header.
            num_points, _mean_track_length = re.findall(r"[-+]?\d*\.\d+|\d+", line)
            num_points = int(num_points)

            xyz = np.zeros((num_points, 3), dtype=np.float32)
            rgb = np.zeros((num_points, 3), dtype=np.float32)
            if COVISIBILITY_CHECK:
                point_meta = {}

            for i in tqdm(range(num_points), desc='reading point cloud'):
                elems = fid.readline().split()
                xyz[i] = list(map(float, elems[1:4]))
                rgb[i] = list(map(int, elems[4:7]))
                if COVISIBILITY_CHECK:
                    point_id = int(elems[0])
                    # The track is a flat list of (IMAGE_ID, POINT2D_IDX) pairs
                    # starting at column 8; every second entry is an image id.
                    image_ids = np.array(tuple(map(int, elems[8::2])))
                    point_meta[point_id] = Point3D(id=point_id,
                                                   arr_idx=i,
                                                   image_ids=image_ids)
        pcd = np.concatenate([xyz, rgb], axis=1)
        if COVISIBILITY_CHECK:
            return pcd, point_meta
        return pcd

    @classmethod
    def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, crop_cam):
        '''
        Create one ``RGBPinholeCapture`` per image listed in ``images.txt``.
        '''
        cameras = cls.read_cameras(cameras_txt_path)
        images_meta = cls.read_images_meta(images_txt_path, images_dir)
        captures = []
        for meta in images_meta.values():
            cur_cam = cameras[meta.camera_id]
            cur_camera_pose = CameraPose(meta.t, meta.r)
            captures.append(RGBPinholeCapture(meta.image_path, cur_cam, cur_camera_pose, crop_cam))
        return captures

    @classmethod
    def read_cameras(cls, cameras_txt_path):
        '''
        Parse ``cameras.txt`` into a dict ``camera_id -> PinholeCamera``.

        Raises ``ValueError`` for any camera model other than ``PINHOLE``.
        '''
        cameras = {}
        with open(cameras_txt_path, "r") as fid:
            line = fid.readline()
            assert line == '# Camera list with one line of data per camera:\n'
            line = fid.readline()
            assert line == '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n'
            line = fid.readline()
            assert re.search(r'^# Number of cameras: \d+\n$', line)
            num_cams = int(re.findall(r"[-+]?\d*\.\d+|\d+", line)[0])

            for _ in tqdm(range(num_cams), desc='reading cameras'):
                elems = fid.readline().split()
                camera_id = int(elems[0])
                camera_type = elems[1]
                if camera_type != "PINHOLE":
                    raise ValueError('Please rectify the 3D model to pinhole camera.')
                width, height, focal_length_x, focal_length_y, cx, cy = map(float, elems[2:8])
                cur_cam = PinholeCamera(width, height, focal_length_x, focal_length_y, cx, cy)
                assert camera_id not in cameras
                cameras[camera_id] = cur_cam
        return cameras

    @classmethod
    def read_images_meta(cls, images_txt_path, images_dir):
        '''
        Parse ``images.txt`` into a dict ``image_id -> ImageMeta``.

        Every image file must exist under ``images_dir``.  The 2D observation
        line of each image is only interpreted when ``COVISIBILITY_CHECK`` is
        set; otherwise ``point3d_id`` and ``p2d`` are ``None``.
        '''
        images_meta = {}
        with open(images_txt_path, "r") as fid:
            line = fid.readline()
            assert line == '# Image list with two lines of data per image:\n'
            line = fid.readline()
            assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
            line = fid.readline()
            assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
            line = fid.readline()
            # Grouped alternation: see read_point_cloud for the rationale.
            assert re.search(r'^# Number of images: \d+, mean observations per image: (?:[-+]?\d*\.\d+|\d+)\n$', line)
            # Unpacking enforces that exactly two numbers appear in the header.
            num_images, _mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line)
            num_images = int(num_images)

            for _ in tqdm(range(num_images), desc='reading images meta'):
                elems = fid.readline().split()
                assert len(elems) == 10

                image_path = os.path.join(images_dir, elems[9])
                assert os.path.isfile(image_path)
                image_id = int(elems[0])
                qw, qx, qy, qz, tx, ty, tz = map(float, elems[1:8])
                t = Translation(np.array([tx, ty, tz], dtype=np.float32))
                r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32))
                camera_id = int(elems[8])
                assert image_id not in images_meta

                # Second line of the record: flat (X, Y, POINT3D_ID) triples.
                line = fid.readline()
                if COVISIBILITY_CHECK:
                    observations = np.array(list(map(float, line.split()))).reshape(-1, 3)
                    # POINT3D_ID == -1 marks a 2D point without a 3D point.
                    xyi = observations[observations[:, 2] != -1]
                    # Builtin int replaces np.int, which NumPy 1.24 removed.
                    idx = xyi[:, 2].astype(int)
                    point3d_id = np.sort(np.array(list(set(idx))))
                    p2d = Point2D(idx, xyi[:, :2])
                else:
                    point3d_id = None
                    p2d = None

                images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d)
        return images_meta
163
+
164
+
165
class ColmapWithDepthAsciiReader(ColmapAsciiReader):
    '''
    COLMAP text reader that also attaches a depth file to every capture.

    Not all images have usable depth estimate from colmap.
    A valid list is needed.
    '''

    @classmethod
    def read_sfm_scene(cls, scene_dir, images_dir, depth_dir, crop_cam):
        '''
        Build an ``SfmScene`` of RGB-D captures from ``scene_dir``.

        The point cloud is only parsed when the module flag ``LOAD_PCD`` is set.
        '''
        point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
        cameras_path = os.path.join(scene_dir, 'cameras.txt')
        images_path = os.path.join(scene_dir, 'images.txt')
        captures = cls.read_captures(images_path, cameras_path, images_dir, depth_dir, crop_cam)
        point_cloud = cls.read_point_cloud(point_cloud_path) if LOAD_PCD else None
        return sfm_scenes.SfmScene(captures, point_cloud)

    @classmethod
    def read_sfm_scene_given_valid_list_path(cls, scene_dir, images_dir, depth_dir, valid_list_json_path, crop_cam):
        '''
        Like ``read_sfm_scene`` but restricted to the images named in the JSON
        list at ``valid_list_json_path``.
        '''
        point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
        cameras_path = os.path.join(scene_dir, 'cameras.txt')
        images_path = os.path.join(scene_dir, 'images.txt')
        valid_list = cls.read_valid_list(valid_list_json_path)
        captures = cls.read_captures_with_depth_given_valid_list(images_path, cameras_path, images_dir, depth_dir, valid_list, crop_cam)
        point_cloud = cls.read_point_cloud(point_cloud_path) if LOAD_PCD else None
        return sfm_scenes.SfmScene(captures, point_cloud)

    @classmethod
    def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, crop_cam):
        '''
        Create one ``RGBDPinholeCapture`` per image in ``images.txt``.

        Images without a depth file are kept: a message is printed and the
        capture receives a placeholder depth path (``<image_path>dummy``).
        '''
        captures = []
        cameras = cls.read_cameras(cameras_txt_path)
        images_meta = cls.read_images_meta(images_txt_path, images_dir)
        for key, meta in images_meta.items():
            cur_cam = cameras[meta.camera_id]
            cur_camera_pose = CameraPose(meta.t, meta.r)
            cur_image_path = meta.image_path
            try:
                # Strip "<images_dir>/" so the lookup uses the path relative to images_dir.
                cur_depth_path = cls.image_path_2_depth_path(cur_image_path[len(images_dir) + 1:], depth_dir)
            except AssertionError:
                # Only the "no depth file" signal is handled here; a bare
                # except would also swallow KeyboardInterrupt and SystemExit.
                print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir))
                # TODO (kept from the original): decide whether such images should be skipped.
                cur_depth_path = f'{cur_image_path}dummy'

            cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, crop_cam)
            cap.point3d_id = meta.point3d_id
            cap.p2d = meta.p2d
            cap.image_id = key
            captures.append(cap)
        return captures

    @classmethod
    def read_captures_with_depth_given_valid_list(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, valid_list, crop_cam):
        '''
        Create ``RGBDPinholeCapture`` objects for the images in ``valid_list``.

        Images without a depth file are reported and skipped.
        '''
        captures = []
        cameras = cls.read_cameras(cameras_txt_path)
        images_meta = cls.read_images_meta_given_valid_list(images_txt_path, images_dir, valid_list)
        for key, meta in images_meta.items():
            cur_cam = cameras[meta.camera_id]
            cur_camera_pose = CameraPose(meta.t, meta.r)
            cur_image_path = meta.image_path
            try:
                cur_depth_path = cls.image_path_2_depth_path(cur_image_path, depth_dir)
            except AssertionError:
                print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir))
                continue
            cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, crop_cam)
            cap.point3d_id = meta.point3d_id
            cap.p2d = meta.p2d
            cap.image_id = key
            captures.append(cap)
        return captures

    @classmethod
    def read_images_meta_given_valid_list(cls, images_txt_path, images_dir, valid_list):
        '''
        Parse ``images.txt`` into ``image_id -> ImageMeta`` for the images
        whose path, relative to four directory levels above the image file,
        is contained in ``valid_list``.
        '''
        images_meta = {}
        with open(images_txt_path, "r") as fid:
            line = fid.readline()
            assert line == '# Image list with two lines of data per image:\n'
            line = fid.readline()
            assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
            line = fid.readline()
            assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
            line = fid.readline()
            # Raw string, the space after "image:" restored, and the
            # alternation grouped so the whole header line is validated.
            assert re.search(r'^# Number of images: \d+, mean observations per image: (?:[-+]?\d*\.\d+|\d+)\n$', line), line
            # Unpacking enforces that exactly two numbers appear in the header.
            num_images, _mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line)
            num_images = int(num_images)

            for _ in tqdm(range(num_images), desc='reading images meta'):
                elems = fid.readline().split()
                assert len(elems) == 10
                # Always consume the observation line so skipped images keep
                # the two-lines-per-image reading position intact.
                line = fid.readline()
                image_path = os.path.join(images_dir, elems[9])
                prefix = os.path.abspath(os.path.join(image_path, '../../../../')) + '/'
                rel_image_path = image_path.replace(prefix, '')
                if rel_image_path not in valid_list:
                    continue
                assert os.path.isfile(image_path), '{0} is not existing'.format(image_path)
                image_id = int(elems[0])
                qw, qx, qy, qz, tx, ty, tz = map(float, elems[1:8])
                t = Translation(np.array([tx, ty, tz], dtype=np.float32))
                r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32))
                camera_id = int(elems[8])
                assert image_id not in images_meta

                if COVISIBILITY_CHECK:
                    observations = np.array(list(map(float, line.split()))).reshape(-1, 3)
                    # POINT3D_ID == -1 marks a 2D point without a 3D point.
                    xyi = observations[observations[:, 2] != -1]
                    # Builtin int replaces np.int, which NumPy 1.24 removed.
                    idx = xyi[:, 2].astype(int)
                    point3d_id = np.sort(np.array(list(set(idx))))
                    p2d = Point2D(idx, xyi[:, :2])
                else:
                    point3d_id = None
                    p2d = None
                images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d)
        return images_meta

    @classmethod
    def read_valid_list(cls, valid_list_json_path):
        '''
        Load the JSON list of valid image paths and return it as a set.

        The file must exist and must not contain duplicates.
        '''
        assert os.path.isfile(valid_list_json_path), valid_list_json_path
        with open(valid_list_json_path, 'r') as f:
            valid_list = json.load(f)
        unique_entries = set(valid_list)
        assert len(valid_list) == len(unique_entries)
        return unique_entries

    @classmethod
    def image_path_2_depth_path(cls, image_path, depth_dir):
        '''
        Locate the depth file belonging to ``image_path`` inside ``depth_dir``.

        ``<basename>.h5`` is tried first, then ``<image_path>.geometric.bin``.
        Raises ``AssertionError`` when neither file exists.
        '''
        depth_file = os.path.splitext(os.path.basename(image_path))[0] + '.h5'
        depth_path = os.path.join(depth_dir, depth_file)
        if not os.path.isfile(depth_path):
            depth_file = image_path + '.geometric.bin'
            depth_path = os.path.join(depth_dir, depth_file)
        if not os.path.isfile(depth_path):
            # Raised explicitly (same type and message as the former assert)
            # because callers rely on it and ``python -O`` strips asserts.
            raise AssertionError('{0} is not file'.format(depth_path))
        return depth_path
third_party/COTR/COTR/datasets/cotr_dataset.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ COTR dataset
3
+ '''
4
+
5
+ import random
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torchvision.transforms import functional as tvtf
10
+ from torch.utils import data
11
+
12
+ from COTR.datasets import megadepth_dataset
13
+ from COTR.utils import debug_utils, utils, constants
14
+ from COTR.projector import pcd_projector
15
+ from COTR.cameras import capture
16
+ from COTR.utils.utils import CropCamConfig
17
+ from COTR.inference import inference_helper
18
+ from COTR.inference.inference_helper import two_images_side_by_side
19
+
20
+
21
class COTRDataset(data.Dataset):
    '''
    Torch dataset that yields a side-by-side image pair together with
    normalised point correspondences between the two halves.
    '''

    def __init__(self, opt, dataset_type: str):
        assert dataset_type in ['train', 'val', 'test']
        assert len(opt.scenes_name_list) > 0
        self.opt = opt
        self.dataset_type = dataset_type
        self.sfm_dataset = megadepth_dataset.MegadepthDataset(opt, dataset_type)

        # Mirror the sampling / augmentation options onto the instance.
        for name in ('kp_pool', 'num_kp', 'bidirectional',
                     'need_rotation', 'max_rotation', 'rotation_chance'):
            setattr(self, name, getattr(opt, name))

    def _trim_corrs(self, in_corrs):
        '''
        Return exactly ``num_kp`` rows: a random draw when there are enough
        rows, otherwise all rows followed by randomly repeated ones.
        '''
        total = in_corrs.shape[0]
        if total < self.num_kp:
            padding = np.random.choice(total, self.num_kp - total)
            return np.concatenate([in_corrs, in_corrs[padding]], axis=0)
        picked = np.random.choice(total, self.num_kp)
        return in_corrs[picked]

    def __len__(self):
        total = self.sfm_dataset.num_queries
        # Validation is capped at 1000 queries.
        return min(1000, total) if self.dataset_type == 'val' else total

    def augment_with_rotation(self, query_cap, nn_cap):
        '''
        Independently rotate each capture with probability ``rotation_chance``
        by an angle drawn uniformly from ``[-max_rotation, max_rotation)``.
        '''
        augmented = []
        for cap in (query_cap, nn_cap):
            if random.random() < self.rotation_chance:
                angle = np.random.uniform(low=-1, high=1) * self.max_rotation
                cap = capture.rotate_capture(cap, angle)
            augmented.append(cap)
        return augmented[0], augmented[1]

    def __getitem__(self, index):
        assert self.opt.k_size == 1
        query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index)
        nn_cap = nn_caps[0]

        if self.need_rotation:
            query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap)

        projector = pcd_projector.PointCloudProjector

        # Lift every pixel of the neighbour image that has depth to a 3D point.
        nn_rows, nn_cols = np.where(nn_cap.depth_map > 0)
        nn_rows, nn_cols = nn_rows[..., None], nn_cols[..., None]
        nn_z = nn_cap.depth_map[np.floor(nn_rows).astype('int'), np.floor(nn_cols).astype('int')]
        nn_xy = np.concatenate([nn_cols, nn_rows], axis=1)
        nn_world, valid_index_1 = projector.pcd_2d_to_pcd_3d_np(
            nn_xy,
            nn_z,
            nn_cap.pinhole_cam.intrinsic_mat,
            motion=nn_cap.cam_pose.camera_to_world,
            return_index=True,
        )

        # Project those points into the query image.
        query_xyz, valid_index_2 = projector.pcd_3d_to_pcd_2d_np(
            nn_world,
            query_cap.pinhole_cam.intrinsic_mat,
            query_cap.cam_pose.world_to_camera[0:3, :],
            query_cap.image.shape[:2],
            keep_z=True,
            crop=True,
            filter_neg=True,
            norm_coord=False,
            return_index=True,
        )
        query_xy = query_xyz[:, 0:2]
        z_projected = query_xyz[:, 2:3]
        rows = np.floor(query_xy[:, 1:2]).astype('int')
        cols = np.floor(query_xy[:, 0:1]).astype('int')
        z_measured = query_cap.depth_map[rows, cols]
        # Keep projections whose depth agrees with the query depth map (< 0.5).
        consistent = (abs(z_measured - z_projected) < 0.5)[:, 0]
        query_xy = query_xy[consistent]

        if query_xy.shape[0] < self.num_kp:
            # Too few correspondences for this pair: fall back to a random sample.
            return self[random.randint(0, len(self) - 1)]

        nn_xy = nn_xy[valid_index_1][valid_index_2][consistent]
        assert nn_xy.shape == query_xy.shape
        corrs = self._trim_corrs(np.concatenate([query_xy, nn_xy], axis=1))

        # flip augmentation
        if np.random.uniform() < 0.5:
            corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0]
            corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2]
            sbs_img = two_images_side_by_side(np.fliplr(query_cap.image), np.fliplr(nn_cap.image))
        else:
            sbs_img = two_images_side_by_side(query_cap.image, nn_cap.image)

        # The neighbour image is the right half of the side-by-side canvas.
        corrs[:, 2] += constants.MAX_SIZE
        corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
        for col, (low, high) in enumerate(((0.0, 0.5), (0.0, 1.0), (0.5, 1.0), (0.0, 1.0))):
            assert (low <= corrs[:, col]).all() and (corrs[:, col] <= high).all()

        out = {
            'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            'corrs': torch.from_numpy(corrs).float(),
        }
        queries, targets = corrs[:, :2], corrs[:, 2:]
        if self.bidirectional:
            # Every correspondence is used in both directions.
            queries, targets = (np.concatenate([queries, targets], axis=0),
                                np.concatenate([targets, queries], axis=0))
        out['queries'] = torch.from_numpy(queries).float()
        out['targets'] = torch.from_numpy(targets).float()
        return out
123
+
124
+
125
class COTRZoomDataset(COTRDataset):
    '''
    COTR dataset variant that samples a seed correspondence, crops both
    images around it at a random zoom level and returns correspondences
    between the two crops.
    '''

    def __init__(self, opt, dataset_type: str):
        assert opt.crop_cam in ['no_crop', 'crop_center']
        assert opt.use_ram == False
        super().__init__(opt, dataset_type)
        self.zoom_start = opt.zoom_start
        self.zoom_end = opt.zoom_end
        self.zoom_levels = opt.zoom_levels
        self.zoom_jitter = opt.zoom_jitter
        # Candidate zoom factors, log-spaced between zoom_start and zoom_end.
        self.zooms = np.logspace(np.log10(opt.zoom_start),
                                 np.log10(opt.zoom_end),
                                 num=opt.zoom_levels)

    def get_corrs(self, from_cap, to_cap, reduced_size=None):
        '''
        Find pixel correspondences from ``from_cap`` to ``to_cap`` via depth.

        Pixels of ``from_cap`` with positive depth are lifted to 3D, projected
        into ``to_cap`` and kept when the projected depth agrees with
        ``to_cap``'s depth map (difference below 0.5).

        ``reduced_size`` optionally limits the number of source pixels that
        are tried; it is clamped to the number of pixels that have depth.

        Returns an ``(N, 4)`` array with columns from_x, from_y, to_x, to_y,
        or ``None`` when no correspondence is found.
        '''
        from_y, from_x = np.where(from_cap.depth_map > 0)
        num_valid = from_y.shape[0]
        if num_valid == 0:
            # Nothing to project; callers already treat None as "no correspondences".
            return None
        from_y, from_x = from_y[..., None], from_x[..., None]
        if reduced_size is not None:
            # Sampling without replacement fails if more samples than
            # candidates are requested, so never ask for more than exist.
            filter_idx = np.random.choice(num_valid, min(reduced_size, num_valid), replace=False)
            from_y, from_x = from_y[filter_idx], from_x[filter_idx]
        from_z = from_cap.depth_map[np.floor(from_y).astype('int'), np.floor(from_x).astype('int')]
        from_xy = np.concatenate([from_x, from_y], axis=1)
        from_3d_world, valid_index_1 = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(from_xy, from_z, from_cap.pinhole_cam.intrinsic_mat, motion=from_cap.cam_pose.camera_to_world, return_index=True)

        to_xyz, valid_index_2 = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np(
            from_3d_world,
            to_cap.pinhole_cam.intrinsic_mat,
            to_cap.cam_pose.world_to_camera[0:3, :],
            to_cap.image.shape[:2],
            keep_z=True,
            crop=True,
            filter_neg=True,
            norm_coord=False,
            return_index=True,
        )

        to_xy = to_xyz[:, 0:2]
        to_z_proj = to_xyz[:, 2:3]
        to_z = to_cap.depth_map[np.floor(to_xy[:, 1:2]).astype('int'), np.floor(to_xy[:, 0:1]).astype('int')]
        mask = (abs(to_z - to_z_proj) < 0.5)[:, 0]
        if mask.sum() > 0:
            return np.concatenate([from_xy[valid_index_1][valid_index_2][mask], to_xy[mask]], axis=1)
        return None

    def get_seed_corr(self, from_cap, to_cap, max_try=100):
        '''
        Return one random correspondence (from_x, from_y, to_x, to_y) found
        among at most ``max_try`` source pixels, or ``None`` if there is none.
        '''
        seed_corr = self.get_corrs(from_cap, to_cap, reduced_size=max_try)
        if seed_corr is None:
            return None
        shuffle = np.random.permutation(seed_corr.shape[0])
        seed_corr = np.take(seed_corr, shuffle, axis=0)
        return seed_corr[0]

    def get_zoomed_cap(self, cap, pos, scale, jitter):
        '''
        Crop ``cap`` to a patch around ``pos`` at the given ``scale`` and
        resize it to ``MAX_SIZE`` x ``MAX_SIZE``.

        The patch centre is shifted by a random fraction (within ``jitter``)
        of the patch width and height.
        '''
        # The first call only measures the patch so the jitter can be
        # expressed relative to its size.
        patch = inference_helper.get_patch_centered_at(cap.image, pos, scale=scale, return_content=False)
        patch = inference_helper.get_patch_centered_at(cap.image,
                                                       pos + np.array([patch.w, patch.h]) * np.random.uniform(-jitter, jitter, 2),
                                                       scale=scale,
                                                       return_content=False)
        zoom_config = CropCamConfig(x=patch.x,
                                    y=patch.y,
                                    w=patch.w,
                                    h=patch.h,
                                    out_w=constants.MAX_SIZE,
                                    out_h=constants.MAX_SIZE,
                                    orig_w=cap.shape[1],
                                    orig_h=cap.shape[0])
        zoom_cap = capture.crop_capture(cap, zoom_config)
        return zoom_cap

    def __getitem__(self, index):
        '''
        Build one training sample; pairs without a usable seed or with fewer
        than ``num_kp`` correspondences are replaced by a random other sample.
        '''
        assert self.opt.k_size == 1
        query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index)
        nn_cap = nn_caps[0]
        if self.need_rotation:
            query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap)

        # find seed
        seed_corr = self.get_seed_corr(nn_cap, query_cap)
        if seed_corr is None:
            return self.__getitem__(random.randint(0, self.__len__() - 1))

        # crop cap: the same zoom for both images, jitter only on the query side
        s = np.random.choice(self.zooms)
        nn_zoom_cap = self.get_zoomed_cap(nn_cap, seed_corr[:2], s, 0)
        query_zoom_cap = self.get_zoomed_cap(query_cap, seed_corr[2:], s, self.zoom_jitter)
        assert nn_zoom_cap.shape == query_zoom_cap.shape == (constants.MAX_SIZE, constants.MAX_SIZE)
        corrs = self.get_corrs(query_zoom_cap, nn_zoom_cap)
        if corrs is None or corrs.shape[0] < self.num_kp:
            return self.__getitem__(random.randint(0, self.__len__() - 1))
        shuffle = np.random.permutation(corrs.shape[0])
        corrs = np.take(corrs, shuffle, axis=0)
        corrs = self._trim_corrs(corrs)

        # flip augmentation
        if np.random.uniform() < 0.5:
            corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0]
            corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2]
            sbs_img = two_images_side_by_side(np.fliplr(query_zoom_cap.image), np.fliplr(nn_zoom_cap.image))
        else:
            sbs_img = two_images_side_by_side(query_zoom_cap.image, nn_zoom_cap.image)

        # The neighbour crop is the right half of the side-by-side canvas;
        # afterwards all coordinates are normalised by the canvas size.
        corrs[:, 2] += constants.MAX_SIZE
        corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
        assert (0.0 <= corrs[:, 0]).all() and (corrs[:, 0] <= 0.5).all()
        assert (0.0 <= corrs[:, 1]).all() and (corrs[:, 1] <= 1.0).all()
        assert (0.5 <= corrs[:, 2]).all() and (corrs[:, 2] <= 1.0).all()
        assert (0.0 <= corrs[:, 3]).all() and (corrs[:, 3] <= 1.0).all()
        out = {
            'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
            'corrs': torch.from_numpy(corrs).float(),
        }
        if self.bidirectional:
            out['queries'] = torch.from_numpy(np.concatenate([corrs[:, :2], corrs[:, 2:]], axis=0)).float()
            out['targets'] = torch.from_numpy(np.concatenate([corrs[:, 2:], corrs[:, :2]], axis=0)).float()
        else:
            out['queries'] = torch.from_numpy(corrs[:, :2]).float()
            out['targets'] = torch.from_numpy(corrs[:, 2:]).float()

        return out
third_party/COTR/COTR/datasets/megadepth_dataset.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ dataset specific layer for megadepth
3
+ '''
4
+
5
+ import os
6
+ import json
7
+ import random
8
+ from collections import namedtuple
9
+
10
+ import numpy as np
11
+
12
+ from COTR.datasets import colmap_helper
13
+ from COTR.global_configs import dataset_config
14
+ from COTR.sfm_scenes import knn_search
15
+ from COTR.utils import debug_utils, utils, constants
16
+
17
# Locates one capture across the loaded scenes: index of the scene, then index of the capture within it.
SceneCapIndex = namedtuple('SceneCapIndex', ['scene_index', 'capture_index'])
18
+
19
+
20
def prefix_of_img_path_for_magedepth(img_path):
    '''
    Return the absolute directory four ".." steps above ``img_path``
    (the image file itself counts as one step), with a trailing slash.
    '''
    levels_up = ['..'] * 4
    root = os.path.abspath(os.path.join(img_path, *levels_up))
    return root + '/'
26
+
27
+
28
class MegadepthSceneDataBase():
    '''
    Class-level cache of loaded scenes and their KNN engines, keyed by scene directory.
    '''
    scenes = {}
    knn_engine_dict = {}

    @classmethod
    def _load_scene(cls, opt, scene_dir_dict):
        '''
        Load the scene described by ``scene_dir_dict`` unless it is already cached.

        Only ``opt.info_level == 'rgbd'`` is implemented: ``'rgb'`` trips an
        assertion and any other value raises ``ValueError``.
        '''
        cache_key = scene_dir_dict['scene_dir']
        if cache_key in cls.scenes:
            return

        if opt.info_level == 'rgb':
            assert 0
        elif opt.info_level == 'rgbd':
            scene = colmap_helper.ColmapWithDepthAsciiReader.read_sfm_scene_given_valid_list_path(
                scene_dir_dict['scene_dir'],
                scene_dir_dict['image_dir'],
                scene_dir_dict['depth_dir'],
                dataset_config[opt.dataset_name]['valid_list_json'],
                opt.crop_cam,
            )
            if opt.use_ram:
                scene.read_data_to_ram(['image', 'depth'])
        else:
            raise ValueError()

        # Build the engine first so a failure leaves both caches untouched.
        engine = knn_search.ReprojRatioKnnSearch(scene)
        cls.scenes[cache_key] = scene
        cls.knn_engine_dict[cache_key] = engine
51
+
52
+
53
class MegadepthDataset():
    '''
    Collection of MegaDepth scenes that serves query captures together with
    nearest-neighbour captures drawn from the training ("db") split.
    '''

    def __init__(self, opt, dataset_type):
        assert dataset_type in ['train', 'val', 'test']
        assert len(opt.scenes_name_list) > 0
        self.opt = opt
        self.dataset_type = dataset_type
        self.use_ram = opt.use_ram
        self.scenes_name_list = opt.scenes_name_list
        self.scenes = None
        self.knn_engine_list = None
        self.total_caps_set = None
        self.query_caps_set = None
        self.db_caps_set = None
        self.img_path_to_scene_cap_index_dict = {}
        self.scene_index_to_db_caps_mask_dict = {}
        # (query set, sorted paths) pair, filled lazily by _sorted_query_caps.
        self._query_order_cache = None
        self._load_scenes()

    @property
    def num_scenes(self):
        return len(self.scenes)

    @property
    def num_queries(self):
        return len(self.query_caps_set)

    @property
    def num_db(self):
        return len(self.db_caps_set)

    def _sorted_query_caps(self):
        '''
        Return the query image paths in sorted order.

        The order is cached and rebuilt only when ``query_caps_set`` is
        replaced or changes size, instead of re-sorting on every lookup.
        '''
        caps = self.query_caps_set
        cached = getattr(self, '_query_order_cache', None)
        if cached is None or cached[0] is not caps or len(cached[1]) != len(caps):
            cached = (caps, sorted(caps))
            self._query_order_cache = cached
        return cached[1]

    def get_scene_cap_index_by_index(self, index):
        '''
        Map a dataset index to the ``(scene_index, capture_index)`` of the
        query image at that position in sorted path order.
        '''
        assert index < len(self.query_caps_set)
        img_path = self._sorted_query_caps()[index]
        return self.img_path_to_scene_cap_index_dict[img_path]

    def _get_common_subset_caps_from_json(self, json_path, total_caps):
        '''
        Return the image paths listed in the JSON file that are also in ``total_caps``.

        The JSON entries are relative paths; they are made absolute with the
        prefix derived from one of the loaded image paths.
        '''
        prefix = prefix_of_img_path_for_magedepth(list(total_caps)[0])
        with open(json_path, 'r') as f:
            listed_caps = {prefix + cap for cap in json.load(f)}
        return set(total_caps) & listed_caps

    def _extend_img_path_to_scene_cap_index_dict(self, img_path_to_cap_index_dict, scene_id):
        '''Register every image of one scene under ``scene_id`` in the global lookup.'''
        for img_path, cap_index in img_path_to_cap_index_dict.items():
            self.img_path_to_scene_cap_index_dict[img_path] = SceneCapIndex(scene_id, cap_index)

    def _create_scene_index_to_db_caps_mask_dict(self, db_caps_set):
        '''
        Group the db captures by scene.

        Returns a dict ``scene_index -> sorted array of capture indices``.
        '''
        caps_per_scene = {}
        for cap in db_caps_set:
            scene_id, cap_id = self.img_path_to_scene_cap_index_dict[cap]
            caps_per_scene.setdefault(scene_id, []).append(cap_id)
        return {scene_id: np.array(sorted(cap_ids)) for scene_id, cap_ids in caps_per_scene.items()}

    def _load_scenes(self):
        '''
        Load (or fetch from the class-level cache) every configured scene and
        derive the query and db capture sets from the dataset JSON lists.
        '''
        scenes = []
        knn_engine_list = []
        total_caps_set = set()
        for scene_id, scene_dir_dict in enumerate(self.scenes_name_list):
            MegadepthSceneDataBase._load_scene(self.opt, scene_dir_dict)
            scene = MegadepthSceneDataBase.scenes[scene_dir_dict['scene_dir']]
            knn_engine = MegadepthSceneDataBase.knn_engine_dict[scene_dir_dict['scene_dir']]
            total_caps_set |= set(scene.img_path_to_index_dict)
            self._extend_img_path_to_scene_cap_index_dict(scene.img_path_to_index_dict, scene_id)
            scenes.append(scene)
            knn_engine_list.append(knn_engine)
        self.scenes = scenes
        self.knn_engine_list = knn_engine_list
        self.total_caps_set = total_caps_set
        self.query_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name][f'{self.dataset_type}_json'], total_caps_set)
        # The database always comes from the training split, whatever the dataset type.
        self.db_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name]['train_json'], total_caps_set)
        self.scene_index_to_db_caps_mask_dict = self._create_scene_index_to_db_caps_mask_dict(self.db_caps_set)

    def get_query_with_knn(self, index):
        '''
        Return the query capture at ``index`` and up to ``opt.k_size``
        neighbours sampled from a pool of ``opt.pool_size`` nearest captures.
        '''
        scene_index, cap_index = self.get_scene_cap_index_by_index(index)
        query_cap = self.scenes[scene_index].captures[cap_index]
        knn_engine = self.knn_engine_list[scene_index]
        # Scenes without any db capture are searched with db_mask=None.
        db_mask = self.scene_index_to_db_caps_mask_dict.get(scene_index)
        pool = knn_engine.get_knn(query_cap, self.opt.pool_size, db_mask=db_mask)
        nn_caps = random.sample(pool, min(len(pool), self.opt.k_size))
        return query_cap, nn_caps
third_party/COTR/COTR/global_configs/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+
4
# Absolute directory of this package; the JSON configuration files live next to this file.
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on the
# platform's locale encoding (which is not UTF-8 on every OS).
with open(os.path.join(__location__, 'dataset_config.json'), 'r', encoding='utf-8') as f:
    dataset_config = json.load(f)
with open(os.path.join(__location__, 'commons.json'), 'r', encoding='utf-8') as f:
    general_config = json.load(f)
9
+ # assert os.path.isdir(general_config['out']), f'Please create {general_config["out"]}'
10
+ # assert os.path.isdir(general_config['tb_out']), f'Please create {general_config["tb_out"]}'
third_party/COTR/COTR/global_configs/commons.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"out": "../../out", "tb_out": "../../tb_out"}
third_party/COTR/COTR/global_configs/dataset_config.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "megadepth": {
3
+ "valid_list_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_valid_list.json",
4
+ "train_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_train.json",
5
+ "val_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_val.json",
6
+ "test_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_test.json",
7
+ "scene_dir": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
8
+ "image_dir": "/media/jiangwei/data_ssd/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
9
+ "depth_dir": "/media/jiangwei/data_ssd/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
10
+ },
11
+
12
+ "megadepth_sushi": {
13
+ "valid_list_json": "/scratch/dataset/megadepth/MegaDepth_v1_SfM/megadepth_valid_list.json",
14
+ "train_json": "/scratch/programs/COTR/sample_data/megadepth_train.json",
15
+ "val_json": "/scratch/programs/COTR/sample_data/megadepth_val.json",
16
+ "test_json": "/scratch/dataset/megadepth/MegaDepth_v1_SfM/megadepth_test.json",
17
+ "scene_dir": "/scratch/dataset/megadepth/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
18
+ "image_dir": "/scratch/dataset/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
19
+ "depth_dir": "/scratch/dataset/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
20
+ },
21
+
22
+ "megadepth_sockeye": {
23
+ "valid_list_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_valid_list.json",
24
+ "train_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_train.json",
25
+ "val_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_val.json",
26
+ "test_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_test.json",
27
+ "scene_dir": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
28
+ "image_dir": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
29
+ "depth_dir": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
30
+ },
31
+
32
+ "megadepth_snubfin": {
33
+ "valid_list_json": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1_SfM/megadepth_valid_list.json",
34
+ "train_json": "/ubc/cs/research/kmyi/jw221/programs/COTR/sample_data/megadepth_train.json",
35
+ "val_json": "/ubc/cs/research/kmyi/jw221/programs/COTR/sample_data/megadepth_val.json",
36
+ "test_json": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1_SfM/megadepth_test.json",
37
+ "scene_dir": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
38
+ "image_dir": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
39
+ "depth_dir": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
40
+ }
41
+ }
third_party/COTR/COTR/inference/inference_helper.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from torchvision.transforms import functional as tvtf
7
+ from tqdm import tqdm
8
+ import PIL
9
+
10
+ from COTR.utils import utils, debug_utils
11
+ from COTR.utils.constants import MAX_SIZE
12
+ from COTR.cameras.capture import crop_center_max_np, pad_to_square_np
13
+ from COTR.utils.utils import ImagePatch
14
+
15
+ THRESHOLD_SPARSE = 0.02
16
+ THRESHOLD_PIXELS_RELATIVE = 0.02
17
+ BASE_ZOOM = 1.0
18
+ THRESHOLD_AREA = 0.02
19
+ LARGE_GPU = True
20
+
21
+
22
def find_prediction_loop(arr):
    """Return the rows of the loop that closes at the last element of ``arr``.

    The loop starts at the first earlier row that is exactly equal to the last
    row and runs up to, but excluding, the last row itself. ``arr`` must have
    shape (N, 2); an IndexError is raised when no earlier row equals the last.
    """
    assert arr.shape[1] == 2, 'requires shape (N, 2)'
    repeats_last = (arr[:-1] == arr[-1]).all(axis=1)
    loop_start = np.flatnonzero(repeats_last)[0]
    return arr[loop_start:-1]
29
+
30
+
31
def two_images_side_by_side(img_a, img_b):
    """Return a new (H, 2W, C) array with ``img_a`` on the left and ``img_b`` on the right.

    Both inputs must share the same 3-D shape and the same dtype.
    """
    assert img_a.shape == img_b.shape, f'{img_a.shape} vs {img_b.shape}'
    assert img_a.dtype == img_b.dtype
    # Unpacking enforces a 3-D (H, W, C) layout, exactly as the canvas construction requires.
    _height, _width, _channels = img_a.shape
    return np.concatenate([img_a, img_b], axis=1)
39
+
40
+
41
def to_square_patches(img):
    """Split an (H, W, C) image into square patches whose side is the short edge.

    A square image yields one patch. An image whose long edge is at most twice
    the short edge yields two (possibly overlapping) patches anchored at the
    top-left and bottom-right corners. More elongated images are not supported
    and raise NotImplementedError.
    """
    height, width, _ = img.shape
    side = min(height, width)
    longest = max(height, width)
    if longest > side * 2:
        raise NotImplementedError

    leading = ImagePatch(img[:side, :side], 0, 0, side, side, width, height)
    if longest == side:
        return [leading]

    warnings.warn('Spatial smoothness in dense optical flow is lost, but sparse matching and triangulation should be fine')
    trailing = ImagePatch(img[-side:, -side:], width - side, height - side, side, side, width, height)
    return [leading, trailing]
59
+
60
+
61
def merge_flow_patches(corrs):
    """Fuse per-patch flow predictions into full-image maps, keeping the lowest error per pixel.

    Each element of ``corrs`` exposes ``patch`` (h, w, 3: two flow channels plus an
    error channel), its placement ``x, y, w, h`` and the full image size ``ow, oh``.
    Returns ``(flow, confidence, cmap)`` where ``confidence`` holds the smallest error
    seen per pixel (100 where nothing was better) and ``cmap`` the index of the
    element that supplied the pixel.
    """
    canvas_hw = (corrs[0].oh, corrs[0].ow)
    confidence = np.full(canvas_hw, 100.0)
    flow = np.zeros(canvas_hw + (2,))
    cmap = np.full(canvas_hw, -1.0)
    for idx, corr in enumerate(corrs):
        rows = slice(corr.y, corr.y + corr.h)
        cols = slice(corr.x, corr.x + corr.w)
        # Paste this patch onto full-size canvases; pixels outside it keep the sentinel error.
        error_canvas = np.full((corr.oh, corr.ow), 100.0)
        error_canvas[rows, cols] = corr.patch[..., 2]
        flow_canvas = np.zeros((corr.oh, corr.ow, 2))
        flow_canvas[rows, cols] = corr.patch[..., :2]
        # argmin picks index 0 on ties, so the newer candidate wins when errors are equal.
        take_new = np.argmin(np.stack([error_canvas, confidence], axis=-1), axis=-1) == 0
        confidence[take_new] = error_canvas[take_new]
        flow[take_new] = flow_canvas[take_new]
        cmap[take_new] = idx
    return flow, confidence, cmap
76
+
77
+
78
def get_patch_centered_at(img, pos, scale=1.0, return_content=True, img_shape=None):
    """Return a square ImagePatch centred on ``pos`` and shifted to stay inside the image.

    pos - [x, y]
    scale - fraction of the short image edge used as patch side, clipped to [0, 1];
            the side is rounded down to an even number of pixels.
    return_content - when False the patch carries ``None`` instead of pixel data.
    img_shape - (h, w, c) to use instead of ``img.shape`` (lets ``img`` be None).
    """
    height, width, _ = img.shape if img_shape is None else img_shape
    side = min(height, width) * np.clip(scale, 0.0, 1.0)
    side = int((side // 2) * 2)
    half = side // 2
    # Clamp the top-left corner so the whole patch lies within the image bounds.
    top = min(max(int(pos[1] - half), 0), height - side)
    left = min(max(int(pos[0] - half), 0), width - side)
    content = img[top:top + side, left:left + side] if return_content else None
    return ImagePatch(content, left, top, side, side, width, height)
103
+
104
+
105
def cotr_patch_flow_exhaustive(model, patches_a, patches_b):
    """Run the model densely on every (patch_a, patch_b) pair.

    For each pair both patches are resized to MAX_SIZE x MAX_SIZE, placed side by
    side, and every pixel of the combined image is used as a query. The raw
    predictions are mapped into normalised [-1, 1] coordinates of the *other*
    full image and resized back to the patch resolution.

    Returns ``(corrs_a, corrs_b)``: two lists of ImagePatch, one entry per pair, whose
    ``patch`` is an (h, w, 3) array of ``[x, y, cycle_error]`` (lower error is better).

    Raises AssertionError('set LARGE_GPU to False') when the single-shot forward
    pass fails while LARGE_GPU is enabled.
    """
    def one_pass(model, img_a, img_b):
        device = next(model.parameters()).device
        assert img_a.shape[0] == img_a.shape[1]
        assert img_b.shape[0] == img_b.shape[1]
        img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
        img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
        img = two_images_side_by_side(img_a, img_b)
        img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None]
        img = img.to(device)

        # Query grid over the side-by-side image: q_grid[i, j] == [j / (2 * MAX_SIZE), i / MAX_SIZE].
        xs = np.arange(MAX_SIZE * 2) / (MAX_SIZE * 2)
        ys = np.arange(MAX_SIZE) / MAX_SIZE
        q_grid = np.stack(np.meshgrid(xs, ys), axis=-1)
        if LARGE_GPU:
            try:
                queries = torch.from_numpy(q_grid.reshape(-1, 2))[None].float().to(device)
                out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0]
                out_list = out.reshape(MAX_SIZE, MAX_SIZE * 2, -1)
            except Exception as err:
                # Keep the historical exception type and message, but chain the real cause,
                # let KeyboardInterrupt/SystemExit through, and survive ``python -O``.
                raise AssertionError('set LARGE_GPU to False') from err
        else:
            # One image row of queries per forward pass to bound memory use.
            out_list = []
            for q in q_grid:
                queries = torch.from_numpy(q)[None].float().to(device)
                out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0]
                out_list.append(out)
            out_list = np.array(out_list)
        in_grid = torch.from_numpy(q_grid).float()[None] * 2 - 1
        out_grid = torch.from_numpy(out_list).float()[None] * 2 - 1
        # Cycle check: sample the prediction map at the predicted locations and compare
        # the result with the original query positions.
        cycle_grid = torch.nn.functional.grid_sample(out_grid.permute(0, 3, 1, 2), out_grid).permute(0, 2, 3, 1)
        confidence = torch.norm(cycle_grid[0, ...] - in_grid[0, ...], dim=-1)
        corr = out_grid[0].clone()
        # Re-map x per half of the side-by-side image (left-half queries: x*2-1, right-half: x*2+1).
        corr[:, :MAX_SIZE, 0] = corr[:, :MAX_SIZE, 0] * 2 - 1
        corr[:, MAX_SIZE:, 0] = corr[:, MAX_SIZE:, 0] * 2 + 1
        corr = torch.cat([corr, confidence[..., None]], dim=-1).numpy()
        return corr[:, :MAX_SIZE, :], corr[:, MAX_SIZE:, :]
    corrs_a = []
    corrs_b = []

    for p_i in patches_a:
        for p_j in patches_b:
            c_i, c_j = one_pass(model, p_i.patch, p_j.patch)
            base_corners = np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]])
            real_corners_j = (np.array([[p_j.x, p_j.y], [p_j.x + p_j.w, p_j.y], [p_j.x + p_j.w, p_j.y + p_j.h], [p_j.x, p_j.y + p_j.h]]) / np.array([p_j.ow, p_j.oh])) * 2 + np.array([-1, -1])
            real_corners_i = (np.array([[p_i.x, p_i.y], [p_i.x + p_i.w, p_i.y], [p_i.x + p_i.w, p_i.y + p_i.h], [p_i.x, p_i.y + p_i.h]]) / np.array([p_i.ow, p_i.oh])) * 2 + np.array([-1, -1])
            # NOTE: T_i is built from patch j's corners and applied to c_i (and vice versa):
            # each prediction is moved into the frame of the patch it points into.
            T_i = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_j[:3].astype(np.float32))
            T_j = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_i[:3].astype(np.float32))
            c_i[..., :2] = c_i[..., :2] @ T_i[:2, :2] + T_i[:, 2]
            c_j[..., :2] = c_j[..., :2] @ T_j[:2, :2] + T_j[:, 2]
            c_i = utils.float_image_resize(c_i, (p_i.h, p_i.w))
            c_j = utils.float_image_resize(c_j, (p_j.h, p_j.w))
            c_i = ImagePatch(c_i, p_i.x, p_i.y, p_i.w, p_i.h, p_i.ow, p_i.oh)
            c_j = ImagePatch(c_j, p_j.x, p_j.y, p_j.w, p_j.h, p_j.ow, p_j.oh)
            corrs_a.append(c_i)
            corrs_b.append(c_j)
    return corrs_a, corrs_b
166
+
167
+
168
def cotr_flow(model, img_a, img_b):
    """Estimate dense correspondences between two images in both directions.

    Returns ``(corr_a, con_a, resample_a, corr_b, con_b, resample_b)``: for each
    direction the merged flow map, its per-pixel cycle error, and the other image
    resampled through that flow.
    """
    def _warp(source_img, corr_grid):
        # Sample ``source_img`` at the locations given by the correspondence grid.
        source = utils.np_img_to_torch_img(source_img)[None].float()
        grid = torch.from_numpy(corr_grid)[None].float()
        return utils.torch_img_to_np_img(torch.nn.functional.grid_sample(source, grid)[0])

    tiles_a = to_square_patches(img_a)
    tiles_b = to_square_patches(img_b)

    flows_a, flows_b = cotr_patch_flow_exhaustive(model, tiles_a, tiles_b)
    corr_a, con_a, _cmap_a = merge_flow_patches(flows_a)
    corr_b, con_b, _cmap_b = merge_flow_patches(flows_b)

    resample_a = _warp(img_b, corr_a)
    resample_b = _warp(img_a, corr_b)
    return corr_a, con_a, resample_a, corr_b, con_b, resample_b
183
+
184
+
185
def cotr_corr_base(model, img_a, img_b, queries_a):
    """Predict, for each query pixel in ``img_a``, its correspondence in ``img_b``.

    Both images are tiled into square patches; every patch pair is evaluated and,
    per query, the prediction with the smallest cycle error is kept. Predictions
    from patches of ``img_a`` that do not contain the query are given infinite error.

    queries_a - (N, 2) array of [x, y] pixel coordinates in ``img_a``; integer
                dtypes are accepted and promoted to float for normalisation.
    Returns an (N, 4) array ``[x_a, y_a, x_b, y_b]``.
    """
    def one_pass(model, img_a, img_b, queries):
        device = next(model.parameters()).device
        assert img_a.shape[0] == img_a.shape[1]
        assert img_b.shape[0] == img_b.shape[1]
        img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
        img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
        img = two_images_side_by_side(img_a, img_b)
        img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None]
        img = img.to(device)

        queries = torch.from_numpy(queries)[None].float().to(device)
        out = model.forward(img, queries)['pred_corrs'].clone().detach()
        # Feed the predictions back as queries; a consistent match returns to its start.
        cycle = model.forward(img, out)['pred_corrs'].clone().detach()

        queries = queries.cpu().numpy()[0]
        out = out.cpu().numpy()[0]
        cycle = cycle.cpu().numpy()[0]
        conf = np.linalg.norm(queries - cycle, axis=1, keepdims=True)
        return np.concatenate([out, conf], axis=1)

    patches_a = to_square_patches(img_a)
    patches_b = to_square_patches(img_b)
    pred_list = []

    # The normalisation below divides in place, which NumPy refuses for integer arrays;
    # promote non-float queries once up front (float inputs keep their dtype).
    float_queries = np.asarray(queries_a)
    if not np.issubdtype(float_queries.dtype, np.floating):
        float_queries = float_queries.astype(np.float64)

    for p_i in patches_a:
        for p_j in patches_b:
            normalized_queries_a = float_queries.copy()
            mask = (normalized_queries_a[:, 0] >= p_i.x) & (normalized_queries_a[:, 1] >= p_i.y) & (normalized_queries_a[:, 0] <= p_i.x + p_i.w) & (normalized_queries_a[:, 1] <= p_i.y + p_i.h)
            # Pixel coordinates -> coordinates of the side-by-side input (x spans two patches).
            normalized_queries_a[:, 0] -= p_i.x
            normalized_queries_a[:, 1] -= p_i.y
            normalized_queries_a[:, 0] /= 2 * p_i.w
            normalized_queries_a[:, 1] /= p_i.h
            pred = one_pass(model, p_i.patch, p_j.patch, normalized_queries_a)
            pred[~mask, 2] = np.inf
            # Side-by-side coordinates (right half) -> pixel coordinates in img_b.
            pred[:, 0] -= 0.5
            pred[:, 0] *= 2 * p_j.w
            pred[:, 0] += p_j.x
            pred[:, 1] *= p_j.h
            pred[:, 1] += p_j.y
            pred_list.append(pred)

    # (pairs, N, 3) -> (N, pairs, 3), then keep the lowest-error candidate per query.
    pred_list = np.stack(pred_list).transpose(1, 0, 2)
    out = []
    for item in pred_list:
        out.append(item[np.argmin(item[..., 2], axis=0)])
    out = np.array(out)[..., :2]
    return np.concatenate([queries_a, out], axis=1)
233
+
234
+
235
# Optional dense interpolation of sparse correspondences, rendered off-screen with vispy.
# When vispy / scipy / the glfw backend are unavailable, ``triangulate_corr`` is set to None.
try:
    from vispy import gloo
    from vispy import app
    from vispy.util.ptime import time
    from scipy.spatial import Delaunay
    from vispy.gloo.wrappers import read_pixels

    app.use_app('glfw')

    vertex_shader = """
        attribute vec4 color;
        attribute vec2 position;
        varying vec4 v_color;
        void main()
        {
            gl_Position = vec4(position, 0.0, 1.0);
            v_color = color;
        } """

    fragment_shader = """
        varying vec4 v_color;
        void main()
        {
            gl_FragColor = v_color;
        } """

    class Canvas(app.Canvas):
        """Hidden canvas that draws a coloured triangle mesh once into an FBO and quits the app."""

        def __init__(self, mesh, color, size):
            # We hide the canvas upon creation.
            app.Canvas.__init__(self, show=False, size=size)
            self._t0 = time()
            # Texture where we render the scene.
            self._rendertex = gloo.Texture2D(shape=self.size[::-1] + (4,), internalformat='rgba32f')
            # FBO.
            self._fbo = gloo.FrameBuffer(self._rendertex,
                                         gloo.RenderBuffer(self.size[::-1]))
            # Regular program that will be rendered to the FBO.
            self.program = gloo.Program(vertex_shader, fragment_shader)
            self.program["position"] = mesh
            self.program['color'] = color
            # We manually draw the hidden canvas.
            self.update()

        def on_draw(self, event):
            # Render in the FBO.
            with self._fbo:
                gloo.clear('black')
                gloo.set_viewport(0, 0, *self.size)
                self.program.draw()
                # Retrieve the contents of the FBO texture.
                self.im = read_pixels((0, 0, self.size[0], self.size[1]), True, out_type='float')
            self._time = time() - self._t0
            # Immediately exit the application.
            app.quit()

    def triangulate_corr(corr, from_shape, to_shape):
        """Interpolate sparse matches into a dense map by rendering a Delaunay mesh.

        corr - (N, 4) rows of [x_from, y_from, x_to, y_to] in pixels.
        from_shape / to_shape - image shapes; only the first two entries (h, w) are used.
        Returns an (h_from, w_from, 2) array of target pixel coordinates.
        """
        corr = corr.copy()
        to_shape = to_shape[:2]
        from_shape = from_shape[:2]
        # Normalise all four coordinates to [0, 1].
        corr = corr / np.concatenate([from_shape[::-1], to_shape[::-1]])
        tri = Delaunay(corr[:, :2])
        mesh = corr[:, :2][tri.simplices].astype(np.float32) * 2 - 1
        mesh[..., 1] *= -1
        # The normalised target coordinates travel through the renderer as the vertex colour.
        color = corr[:, 2:][tri.simplices].astype(np.float32)
        color = np.concatenate([color, np.ones_like(color[..., 0:2])], axis=-1)
        c = Canvas(mesh.reshape(-1, 2), color.reshape(-1, 4), size=(from_shape[::-1]))
        app.run()
        render = c.im.copy()
        render = render[..., :2]
        render *= np.array(to_shape[::-1])
        return render
except Exception:
    # Best-effort optional feature: any import/backend failure disables it, but
    # KeyboardInterrupt and SystemExit are no longer swallowed.
    print('cannot use vispy, setting triangulate_corr as None')
    triangulate_corr = None
third_party/COTR/COTR/inference/refinement_task.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torchvision.transforms import functional as tvtf
6
+ import imageio
7
+ import PIL
8
+
9
+ from COTR.inference.inference_helper import BASE_ZOOM, THRESHOLD_PIXELS_RELATIVE, get_patch_centered_at, two_images_side_by_side, find_prediction_loop
10
+ from COTR.utils import debug_utils, utils
11
+ from COTR.utils.constants import MAX_SIZE
12
+ from COTR.utils.utils import ImagePatch
13
+
14
+
15
class RefinementTask():
    """State machine that refines one correspondence through a sequence of zoom levels.

    A task repeatedly hands out a job (a pair of patches plus a normalised query,
    via one of the ``get_task*`` methods) and receives the raw prediction back
    through ``step``. After the last zoom level the task becomes ``'finished'`` and
    ``result`` is ``'good'`` or ``'bad'`` depending on ``conclude()``.
    """

    def __init__(self, image_from, image_to, loc_from, loc_to, area_from, area_to, converge_iters, zoom_ins, identifier=None):
        """
        image_from / image_to - source and target images (arrays with a ``shape``).
        loc_from - fixed query location [x, y] in ``image_from``.
        loc_to - initial guess [x, y] in ``image_to``.
        area_from / area_to - relative areas used to balance the two patch scales.
        converge_iters - maximum number of iterations spent at the last zoom level.
        zoom_ins - sequence of zoom factors, visited in order.
        identifier - optional caller-side id returned alongside the result.
        """
        self.identifier = identifier
        self.image_from = image_from
        self.image_to = image_to
        self.loc_from = loc_from
        self.best_loc_to = loc_to
        self.cur_loc_to = loc_to
        self.area_from = area_from
        self.area_to = area_to
        # The image with the smaller area keeps BASE_ZOOM; the other gets a larger scale.
        if self.area_from < self.area_to:
            self.s_from = BASE_ZOOM
            self.s_to = BASE_ZOOM * np.sqrt(self.area_to / self.area_from)
        else:
            self.s_to = BASE_ZOOM
            self.s_from = BASE_ZOOM * np.sqrt(self.area_from / self.area_to)

        self.cur_job = {}
        self.status = 'unfinished'
        self.result = 'unknown'

        self.converge_iters = converge_iters
        self.zoom_ins = zoom_ins
        self.cur_zoom_idx = 0
        self.cur_iter = 0
        self.total_iter = 0

        self.loc_to_at_zoom = []
        self.loc_history = [loc_to]
        self.all_loc_to_dict = {}
        self.job_history = []
        self.submitted = False

    @property
    def cur_zoom(self):
        """Zoom factor of the current level (indexes ``zoom_ins`` with ``cur_zoom_idx``)."""
        return self.zoom_ins[self.cur_zoom_idx]

    @property
    def confidence_scaling_factor(self):
        """Ratio of the current zoom to the first zoom; 1.0 while on the first level."""
        if self.cur_zoom_idx > 0:
            conf_scaling = float(self.cur_zoom) / float(self.zoom_ins[0])
        else:
            conf_scaling = 1.0
        return conf_scaling

    # ------------------------------------------------------------------ helpers
    @staticmethod
    def _without_content(patch):
        """Return a copy of ``patch`` that keeps its geometry but carries no pixel data."""
        return ImagePatch(None, patch.x, patch.y, patch.w, patch.h, patch.ow, patch.oh)

    def _query_in_patch(self, patch_from):
        """Return ``loc_from`` as a (1, 2) float tensor in side-by-side input coordinates."""
        origin = np.array([patch_from.x, patch_from.y])
        # x is divided by twice the width because the model input holds two patches side by side.
        extent = np.array([patch_from.w * 2, patch_from.h])
        return torch.from_numpy((np.array(self.loc_from) - origin) / extent)[None].float()

    def _register_job(self, patch_from, patch_to, include_img_slot):
        """Record the job being handed out and mark the task as submitted."""
        job = {'patch_from': patch_from,
               'patch_to': patch_to,
               'loc_from': self.loc_from,
               'loc_to': self.cur_loc_to,
               }
        if include_img_slot:
            job['img'] = None
        self.cur_job = job
        self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w))
        assert not self.submitted
        self.submitted = True

    # ------------------------------------------------------------------ job creation
    def peek(self):
        """Return the job that would be issued next, without changing any state."""
        assert self.status == 'unfinished'
        patch_from = get_patch_centered_at(None, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False, img_shape=self.image_from.shape)
        patch_to = get_patch_centered_at(None, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False, img_shape=self.image_to.shape)
        top_job = {'patch_from': patch_from,
                   'patch_to': patch_to,
                   'loc_from': self.loc_from,
                   'loc_to': self.cur_loc_to,
                   }
        return top_job

    def get_task_pilot(self, pilot):
        """Issue a job that reuses the patch geometry of ``pilot``'s current job.

        Returns ``(None, query)``; no image tensor is produced.
        """
        assert self.status == 'unfinished'
        patch_from = self._without_content(pilot.cur_job['patch_from'])
        patch_to = self._without_content(pilot.cur_job['patch_to'])
        query = self._query_in_patch(patch_from)
        self._register_job(patch_from, patch_to, include_img_slot=True)
        return None, query

    def get_task_fast(self):
        """Issue a job with patch geometry only (no pixel data is cropped).

        Returns ``(None, query)``.
        """
        assert self.status == 'unfinished'
        patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False)
        patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False)
        query = self._query_in_patch(patch_from)
        self._register_job(patch_from, patch_to, include_img_slot=True)
        return None, query

    def get_task(self):
        """Issue a full job: crop both patches and build the normalised model input.

        Returns ``(img, query)`` where ``img`` is the normalised side-by-side tensor
        (MAX_SIZE high, 2 * MAX_SIZE wide) and ``query`` a (1, 2) float tensor.
        """
        assert self.status == 'unfinished'
        patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom)
        patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom)

        query = self._query_in_patch(patch_from)

        img_from = patch_from.patch
        img_to = patch_to.patch
        assert img_from.shape[0] == img_from.shape[1]
        assert img_to.shape[0] == img_to.shape[1]

        img_from = np.array(PIL.Image.fromarray(img_from).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
        img_to = np.array(PIL.Image.fromarray(img_to).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
        img = two_images_side_by_side(img_from, img_to)
        img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()

        # Only the geometry is remembered; the cropped pixels are not kept on the task.
        self._register_job(self._without_content(patch_from), self._without_content(patch_to), include_img_slot=False)

        return img, query

    # ------------------------------------------------------------------ progression
    def next_zoom(self):
        """Advance to the next zoom level, finishing the task if this was the last one.

        NOTE: ``cur_zoom_idx`` is incremented even after finishing, so ``cur_zoom``
        must not be read on a finished task.
        """
        if self.cur_zoom_idx >= len(self.zoom_ins) - 1:
            self.status = 'finished'
            if self.conclude() is None:
                self.result = 'bad'
            else:
                self.result = 'good'
        self.cur_zoom_idx += 1
        self.cur_iter = 0
        self.loc_to_at_zoom = []

    def scale_to_loc(self, raw_to_loc):
        """Convert a raw side-by-side prediction into pixel coordinates of ``image_to``."""
        raw_to_loc = raw_to_loc.copy()
        patch_b = self.cur_job['patch_to']
        # The target patch occupies the right half (x in [0.5, 1]) of the model input.
        raw_to_loc[0] = (raw_to_loc[0] - 0.5) * 2
        loc_to = raw_to_loc * np.array([patch_b.w, patch_b.h])
        loc_to = loc_to + np.array([patch_b.x, patch_b.y])
        return loc_to

    def step(self, raw_to_loc):
        """Consume the prediction for the submitted job and update the task state."""
        assert self.submitted
        self.submitted = False
        loc_to = self.scale_to_loc(raw_to_loc)
        self.total_iter += 1
        self.loc_to_at_zoom.append(loc_to)
        self.cur_loc_to = loc_to
        zoom_finished = False
        if self.cur_zoom_idx == len(self.zoom_ins) - 1:
            # converge at the last level
            if len(self.loc_to_at_zoom) >= 2:
                # An exact repeat of an earlier prediction means the iteration entered a cycle.
                zoom_finished = np.prod(self.loc_to_at_zoom[:-1] == loc_to, axis=1, keepdims=True).any()
            if self.cur_iter >= self.converge_iters - 1:
                zoom_finished = True
            self.cur_iter += 1
        else:
            # finish immediately for other levels
            zoom_finished = True
        if zoom_finished:
            self.all_loc_to_dict[self.cur_zoom] = np.array(self.loc_to_at_zoom).copy()
            last_level_loc_to = self.all_loc_to_dict[self.cur_zoom]
            if len(last_level_loc_to) >= 2:
                has_loop = np.prod(last_level_loc_to[:-1] == last_level_loc_to[-1], axis=1, keepdims=True).any()
                if has_loop:
                    # Use the mean of the cycle as the estimate for this level.
                    loop = find_prediction_loop(last_level_loc_to)
                    loc_to = loop.mean(axis=0)
            self.loc_history.append(loc_to)
            self.best_loc_to = loc_to
            self.cur_loc_to = self.best_loc_to
            self.next_zoom()

    # ------------------------------------------------------------------ results
    def conclude(self, force=False):
        """Return ``[x_from, y_from, x_to, y_to]``, or None when the estimate is unstable.

        Unless ``force`` is set, the task is rejected when the per-axis standard
        deviation of ``loc_history`` reaches THRESHOLD_PIXELS_RELATIVE times the
        largest dimension of ``image_to.shape``.
        """
        loc_history = np.array(self.loc_history)
        if (not force) and (max(loc_history.std(axis=0)) >= THRESHOLD_PIXELS_RELATIVE * max(*self.image_to.shape)):
            return None
        return np.concatenate([self.loc_from, self.best_loc_to])

    def conclude_intermedia(self):
        """Return ``loc_history`` and ``job_history`` concatenated column-wise.

        Both histories must have the same number of rows for the concatenation to succeed.
        """
        return np.concatenate([np.array(self.loc_history), np.array(self.job_history)], axis=1)
third_party/COTR/COTR/inference/sparse_engine.py ADDED
@@ -0,0 +1,427 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Inference engine for sparse image pair correspondences
3
+ '''
4
+
5
+ import time
6
+ import random
7
+
8
+ import numpy as np
9
+ import torch
10
+
11
+ from COTR.inference.inference_helper import THRESHOLD_SPARSE, THRESHOLD_AREA, cotr_flow, cotr_corr_base
12
+ from COTR.inference.refinement_task import RefinementTask
13
+ from COTR.utils import debug_utils, utils
14
+ from COTR.cameras.capture import stretch_to_square_np
15
+
16
+
17
+ class SparseEngine():
18
def __init__(self, model, batch_size, mode='stretching'):
    """Create an engine around ``model``.

    batch_size - maximum number of refinement jobs per forward pass.
    mode - 'stretching' or 'tile'; how non-square images are handled.
    """
    assert mode in ['stretching', 'tile']
    self.model, self.batch_size, self.mode = model, batch_size, mode
    # Running count of jobs sent through ``infer_batch``.
    self.total_tasks = 0
24
+
25
def form_batch(self, tasks, zoom=None):
    """Collect up to ``self.batch_size`` pending jobs into stacked tensors.

    Only tasks that are unfinished and not yet submitted are used; when ``zoom``
    is given, tasks at a different current zoom are skipped. Returns
    ``(task_ref, img_batch, query_batch)``, or three empty lists when nothing
    is pending.
    """
    picked, images, queries = [], [], []
    for task in tasks:
        if task.status != 'unfinished' or task.submitted != False:
            continue
        if zoom is not None and task.cur_zoom != zoom:
            continue
        picked.append(task)
        image, query = task.get_task()
        images.append(image)
        queries.append(query)
        if len(picked) >= self.batch_size:
            break
    if not picked:
        return [], [], []
    return picked, torch.stack(images), torch.stack(queries)
46
+
47
def infer_batch(self, img_batch, query_batch):
    """Run the model on one batch and return the prediction of the first query of each item.

    Raises ValueError when the prediction contains NaN.
    """
    self.total_tasks += img_batch.shape[0]
    target_device = next(self.model.parameters()).device
    prediction = self.model(img_batch.to(target_device), query_batch.to(target_device))['pred_corrs']
    # Each batch item carries a single query, hence index 0 on the query axis.
    first_query_pred = prediction.clone().detach().cpu().numpy()[:, 0, :]
    if utils.has_nan(first_query_pred):
        raise ValueError('NaN in prediction')
    return first_query_pred
57
+
58
def conclude_tasks(self, tasks, return_idx=False, force=False,
                   offset_x_from=0,
                   offset_y_from=0,
                   offset_x_to=0,
                   offset_y_to=0,
                   img_a_shape=None,
                   img_b_shape=None):
    """Gather the correspondences of all finished tasks.

    tasks - iterable of tasks exposing ``status``, ``identifier`` and ``conclude(force)``.
    return_idx - also return the identifiers of the kept tasks.
    force - passed to ``conclude``; also disables the border filter.
    offset_* - subtracted from the [x_from, y_from, x_to, y_to] columns.
    img_a_shape / img_b_shape - (h, w) pairs; when both are given (and not ``force``),
        only rows strictly inside (0, w) x (0, h) in both images are kept.
    Returns the (N, 4) correspondences, plus the identifier array when requested.
    """
    corrs = []
    idx = []
    for t in tasks:
        if t.status == 'finished':
            out = t.conclude(force)
            if out is not None:
                corrs.append(np.array(out))
                idx.append(t.identifier)
    corrs = np.array(corrs)
    idx = np.array(idx)
    if corrs.shape[0] > 0:
        corrs -= np.array([offset_x_from, offset_y_from, offset_x_to, offset_y_to])
        if img_a_shape is not None and img_b_shape is not None and not force:
            # Shapes are (h, w) while columns are (x, y), hence the reversal.
            upper = np.concatenate([img_a_shape[::-1], img_b_shape[::-1]])
            # Plain boolean mask; the former ``np.bool`` alias no longer exists in NumPy >= 1.24.
            border_mask = (corrs < upper).all(axis=1) & (corrs > 0).all(axis=1)
            corrs = corrs[border_mask]
            idx = idx[border_mask]
    if return_idx:
        return corrs, idx
    return corrs
85
+
86
def num_finished_tasks(self, tasks):
    """Return how many of ``tasks`` have status 'finished'."""
    return sum(1 for task in tasks if task.status == 'finished')
92
+
93
def num_good_tasks(self, tasks):
    """Return how many of ``tasks`` have result 'good'."""
    return sum(1 for task in tasks if task.result == 'good')
99
+
100
def gen_tasks_w_known_scale(self, img_a, img_b, queries_a, areas, zoom_ins=[1.0], converge_iters=1, max_corrs=1000):
    """Create one refinement task per query when the relative areas are already known.

    Only valid in 'tile' mode. ``areas`` supplies ``(area_from, area_to)``;
    the initial guesses come from ``cotr_corr_base``. ``max_corrs`` is accepted
    for signature compatibility and not used here.
    """
    assert self.mode == 'tile'
    initial_matches = cotr_corr_base(self.model, img_a, img_b, queries_a)
    return [
        RefinementTask(img_a, img_b, match[:2], match[2:], areas[0], areas[1], converge_iters, zoom_ins)
        for match in initial_matches
    ]
107
+
108
def gen_tasks(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, force=False, areas=None):
    """Create refinement tasks seeded from a dense flow estimate between the two images.

    zoom_ins / converge_iters - forwarded to every RefinementTask.
    max_corrs - number of seeds drawn per direction when ``queries_a`` is None, and
        the target task count when filling up with low-confidence queries.
    queries_a - optional (N, 2) array of [x, y] query pixels in ``img_a``.
    force - with ``queries_a``: create a task for every query (clamped to the
        image) regardless of confidence.
    areas - optional ``(area_from, area_to)``; delegates to
        ``gen_tasks_w_known_scale`` (requires ``queries_a`` and ``force=True``).
    Returns the list of created tasks.
    """
    if areas is not None:
        assert queries_a is not None
        assert force == True
        assert max_corrs >= queries_a.shape[0]
        return self.gen_tasks_w_known_scale(img_a, img_b, queries_a, areas, zoom_ins=zoom_ins, converge_iters=converge_iters, max_corrs=max_corrs)
    if self.mode == 'stretching':
        if img_a.shape[0] != img_a.shape[1] or img_b.shape[0] != img_b.shape[1]:
            # Estimate the flow on square-stretched copies, then resize the maps back.
            img_a_shape = img_a.shape
            img_b_shape = img_b.shape
            img_a_sq = stretch_to_square_np(img_a.copy())
            img_b_sq = stretch_to_square_np(img_b.copy())
            corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
                                                                             img_a_sq,
                                                                             img_b_sq
                                                                             )
            corr_a = utils.float_image_resize(corr_a, img_a_shape[:2])
            con_a = utils.float_image_resize(con_a, img_a_shape[:2])
            resample_a = utils.float_image_resize(resample_a, img_a_shape[:2])
            corr_b = utils.float_image_resize(corr_b, img_b_shape[:2])
            con_b = utils.float_image_resize(con_b, img_b_shape[:2])
            resample_b = utils.float_image_resize(resample_b, img_b_shape[:2])
        else:
            corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
                                                                             img_a,
                                                                             img_b
                                                                             )
    elif self.mode == 'tile':
        corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
                                                                         img_a,
                                                                         img_b
                                                                         )
    else:
        raise ValueError(f'unsupported mode: {self.mode}')
    # Pixels whose cycle error is below the threshold are usable seeds.
    mask_a = con_a < THRESHOLD_SPARSE
    mask_b = con_b < THRESHOLD_SPARSE
    # Fraction of each image with a low cycle error; used to balance patch scales.
    area_a = (con_a < THRESHOLD_AREA).sum() / mask_a.size
    area_b = (con_b < THRESHOLD_AREA).sum() / mask_b.size
    tasks = []

    if queries_a is None:
        # Rows of index_* are (row, col); np.random.choice samples with replacement.
        index_a = np.where(mask_a)
        index_a = np.array(index_a).T
        index_a = index_a[np.random.choice(len(index_a), min(max_corrs, len(index_a)))]
        index_b = np.where(mask_b)
        index_b = np.array(index_b).T
        index_b = index_b[np.random.choice(len(index_b), min(max_corrs, len(index_b)))]
        for pos in index_a:
            loc_from = pos[::-1]
            # Flow values are in [-1, 1]; map them to pixel coordinates (x, y) of img_b.
            loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
            tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins))
        for pos in index_b:
            # trick: suppose to fix the query point location(loc_from),
            # but here it fixes the first guess(loc_to).
            loc_from = pos[::-1]
            loc_to = (corr_b[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_a.shape[:2][::-1]
            tasks.append(RefinementTask(img_a, img_b, loc_to, loc_from, area_a, area_b, converge_iters, zoom_ins))
    else:
        if force:
            for i, loc_from in enumerate(queries_a):
                pos = loc_from[::-1]
                # Clamp to the flow map; builtin ``int`` replaces the ``np.int`` alias
                # that was removed in NumPy 1.24.
                pos = np.array([np.clip(pos[0], 0, corr_a.shape[0] - 1), np.clip(pos[1], 0, corr_a.shape[1] - 1)], dtype=int)
                loc_to = (corr_a[tuple(pos)].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
                tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
        else:
            # First pass: queries that fall on confident pixels.
            for i, loc_from in enumerate(queries_a):
                pos = loc_from[::-1]
                if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any():
                    continue
                if mask_a[tuple(np.floor(pos).astype('int'))]:
                    loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
                    tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
            if len(tasks) < max_corrs:
                # Second pass: top up with in-bounds queries on non-confident pixels.
                extra = max_corrs - len(tasks)
                counter = 0
                for i, loc_from in enumerate(queries_a):
                    if counter >= extra:
                        break
                    pos = loc_from[::-1]
                    if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any():
                        continue
                    if mask_a[tuple(np.floor(pos).astype('int'))] == False:
                        loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
                        tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
                        counter += 1
    return tasks
196
+
197
def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None):
    '''
    Generate refinement tasks for an image pair and iterate them to completion.

    currently only support fixed queries_a

    The images (and queries_a, when given) are copied up front, tasks are
    created through self.gen_tasks, and batches built by self.form_batch are
    pushed through self.infer_batch until either no batch can be formed or
    at least max_corrs tasks are reported good by self.num_good_tasks.

    Returns the raw task list when return_tasks_only is set; otherwise the
    output of self.conclude_tasks truncated to max_corrs entries, together
    with the matching indices when return_idx is set.
    '''
    # Private copies so the caller's arrays are never touched downstream.
    img_a, img_b = img_a.copy(), img_b.copy()
    shape_a, shape_b = img_a.shape[:2], img_b.shape[:2]
    if queries_a is not None:
        queries_a = queries_a.copy()
    tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas)

    while True:
        good = self.num_good_tasks(tasks)
        print(f'{good} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}')
        # The batch is formed before the stop test, exactly as before.
        task_ref, img_batch, query_batch = self.form_batch(tasks)
        if len(task_ref) == 0 or good >= max_corrs:
            break
        predictions = self.infer_batch(img_batch, query_batch)
        for task, prediction in zip(task_ref, predictions):
            task.step(prediction)

    if return_tasks_only:
        return tasks
    if not return_idx:
        corrs = self.conclude_tasks(tasks, force=force,
                                    img_a_shape=shape_a,
                                    img_b_shape=shape_b,)
        return corrs[:max_corrs]
    corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force,
                                     img_a_shape=shape_a,
                                     img_b_shape=shape_b,)
    return corrs[:max_corrs], idx[:max_corrs]
234
+
235
def cotr_corr_multiscale_with_cycle_consistency(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, return_cycle_error=False):
    '''
    Estimate correspondences A->B, map them back B->A, and keep the matches
    with the smallest cycle error.

    The forward pass is asked for max_corrs / EXTRACTION_RATE matches (capped
    by the number of queries, when queries_a is given). Columns 2: of the
    forward result are then used as queries for the backward pass, and the
    cycle error is the distance between columns :2 of the forward result and
    columns 2: of the backward result.

    Returns the best max_corrs forward matches sorted by ascending cycle
    error. When return_idx and/or return_cycle_error are set, a list is
    returned instead with the indices and/or errors appended in that order.
    '''
    EXTRACTION_RATE = 0.3
    budget = int(max_corrs / EXTRACTION_RATE)
    if queries_a is not None:
        budget = min(budget, queries_a.shape[0])
        queries_a = queries_a.copy()

    corr_f, idx_f = self.cotr_corr_multiscale(img_a.copy(), img_b.copy(),
                                              zoom_ins=zoom_ins,
                                              converge_iters=converge_iters,
                                              max_corrs=budget,
                                              queries_a=queries_a,
                                              return_idx=True)
    assert corr_f.shape[0] > 0

    # Backward pass: swap the images and query with the forward targets.
    corr_b, idx_b = self.cotr_corr_multiscale(img_b.copy(), img_a.copy(),
                                              zoom_ins=zoom_ins,
                                              converge_iters=converge_iters,
                                              max_corrs=corr_f.shape[0],
                                              queries_a=corr_f[:, 2:].copy(),
                                              return_idx=True)
    assert corr_b.shape[0] > 0

    # idx_b selects the forward rows that survived the backward pass.
    surviving = corr_f[idx_b]
    cycle_errors = np.linalg.norm(surviving[:, :2] - corr_b[:, 2:], axis=1)
    order = np.argsort(cycle_errors)

    results = [surviving[order][:max_corrs]]
    if return_idx:
        results.append(idx_f[idx_b][order][:max_corrs])
    if return_cycle_error:
        results.append(cycle_errors[order][:max_corrs])
    return results[0] if len(results) == 1 else results
265
+
266
+
267
class FasterSparseEngine(SparseEngine):
    '''
    search and merge nearby tasks to accelerate inference speed.
    It will make spatial accuracy slightly worse.

    A "pilot" task supplies the image patches; other pending tasks whose
    loc_from / loc_to fall inside the central part of the pilot's patches are
    evaluated in the same forward pass as extra queries.
    '''

    def __init__(self, model, batch_size, mode='stretching', max_load=256):
        '''max_load caps how many extra tasks may be attached to one pilot.'''
        super().__init__(model, batch_size, mode=mode)
        self.max_load = max_load

    @staticmethod
    def _pending_at_zoom(task, zoom):
        '''True when the task is unfinished, not yet submitted, and at the given zoom.'''
        return task.status == 'unfinished' and task.submitted == False and task.cur_zoom == zoom

    @staticmethod
    def _safe_window(patch, ratio):
        '''Return (left, right, upper, lower) of the patch shrunk about its centre by ratio.'''
        center_x = patch.x + patch.w/2
        center_y = patch.y + patch.h/2
        return (center_x - patch.w/2 * ratio,
                center_x + patch.w/2 * ratio,
                center_y - patch.h/2 * ratio,
                center_y + patch.h/2 * ratio)

    def infer_batch_grouped(self, img_batch, query_batch):
        '''
        Run the model on a grouped batch and return 'pred_corrs' as a numpy array.

        The forward pass runs under torch.no_grad(): the result was always
        detached, so building an autograd graph only wasted memory.
        '''
        device = next(self.model.parameters()).device
        img_batch = img_batch.to(device)
        query_batch = query_batch.to(device)
        with torch.no_grad():
            pred = self.model(img_batch, query_batch)['pred_corrs']
        return pred.clone().detach().cpu().numpy()

    def get_tasks_map(self, zoom, tasks):
        '''
        Collect the pending tasks at this zoom level.

        Returns (maps, ids): maps holds loc_from concatenated with loc_to for
        each pending task, ids holds that task's index in tasks.
        '''
        maps = []
        ids = []
        for index, task in enumerate(tasks):
            if not self._pending_at_zoom(task, zoom):
                continue
            info = task.peek()
            maps.append(np.concatenate([info['loc_from'], info['loc_to']]))
            ids.append(index)
        return np.array(maps), np.array(ids)

    def form_squad(self, zoom, pilot, pilot_id, tasks, tasks_map, task_ids, bookkeeping):
        '''
        Attach up to max_load still-available tasks to a pilot task.

        pilot_id indexes tasks_map / task_ids / bookkeeping (not tasks).
        bookkeeping is a boolean array marking rows that are still free; it
        is updated in place and also returned.

        Returns (members, img, queries, bookkeeping) where members starts
        with the pilot and queries stacks all member queries along dim 1.
        '''
        assert self._pending_at_zoom(pilot, zoom)
        SAFE_AREA = 0.5
        pilot_info = pilot.peek()
        from_left, from_right, from_upper, from_lower = self._safe_window(pilot_info['patch_from'], SAFE_AREA)
        to_left, to_right, to_upper, to_lower = self._safe_window(pilot_info['patch_to'], SAFE_AREA)

        img, query = pilot.get_task()
        assert pilot.submitted == True
        members = [pilot]
        queries = [query]
        # The pilot itself must not be picked up again as a passenger.
        bookkeeping[pilot_id] = False

        # Columns 0-1 come from loc_from, columns 2-3 from loc_to.
        inside = ((tasks_map[:, 0] > from_left) &
                  (tasks_map[:, 0] < from_right) &
                  (tasks_map[:, 1] > from_upper) &
                  (tasks_map[:, 1] < from_lower) &
                  (tasks_map[:, 2] > to_left) &
                  (tasks_map[:, 2] < to_right) &
                  (tasks_map[:, 3] > to_upper) &
                  (tasks_map[:, 3] < to_lower))
        loads = np.where(inside * bookkeeping)[0][: self.max_load]

        for ti in task_ids[loads]:
            passenger = tasks[ti]
            assert self._pending_at_zoom(passenger, zoom)
            _, query = passenger.get_task_pilot(pilot)
            members.append(passenger)
            queries.append(query)
        queries = torch.stack(queries, axis=1)
        bookkeeping[loads] = False
        return members, img, queries, bookkeeping

    def form_grouped_batch(self, zoom, tasks):
        '''
        Build up to batch_size squads from the pending tasks at this zoom.

        Returns (task_ref, img_batch, query_batch); task_ref is a list of
        member lists, and the query tensors are zero-padded along dim 1 to a
        common length. Returns ([], [], []) when no squad could be formed.
        '''
        task_ref = []
        img_batch = []
        query_batch = []
        tasks_map, task_ids = self.get_tasks_map(zoom, tasks)
        # Visit candidates in random order so pilots are not spatially biased.
        shuffle = np.random.permutation(tasks_map.shape[0])
        tasks_map = np.take(tasks_map, shuffle, axis=0)
        task_ids = np.take(task_ids, shuffle, axis=0)
        bookkeeping = np.ones_like(task_ids).astype(bool)

        for row, ti in enumerate(task_ids):
            candidate = tasks[ti]
            # Tasks absorbed by an earlier squad are already submitted.
            if not self._pending_at_zoom(candidate, zoom):
                continue
            members, img, queries, bookkeeping = self.form_squad(zoom, candidate, row, tasks, tasks_map, task_ids, bookkeeping)
            task_ref.append(members)
            img_batch.append(img)
            query_batch.append(queries)
            if len(task_ref) >= self.batch_size:
                break
        if len(task_ref) == 0:
            return [], [], []

        max_len = max([q.shape[1] for q in query_batch])
        for k, q in enumerate(query_batch):
            padding = torch.zeros([1, max_len - q.shape[1], 2])
            query_batch[k] = torch.cat([q, padding], axis=1)
        img_batch = torch.stack(img_batch)
        query_batch = torch.cat(query_batch)
        return task_ref, img_batch, query_batch

    def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None):
        '''
        Grouped version of the multi-scale correspondence search.

        currently only support fixed queries_a

        For every zoom level, grouped batches are processed first; once a
        grouped invocation solves no more than batch_size sub-tasks, the
        remaining tasks at that zoom fall back to the ungrouped form_batch /
        infer_batch path. The return value follows the same rules as the
        ungrouped method: the task list, or truncated correspondences with
        optional indices.
        '''
        img_a = img_a.copy()
        img_b = img_b.copy()
        img_a_shape = img_a.shape[:2]
        img_b_shape = img_b.shape[:2]
        if queries_a is not None:
            queries_a = queries_a.copy()
        tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas)
        for zm in zoom_ins:
            print(f'======= Zoom: {zm} ======')
            while True:
                num_g = self.num_good_tasks(tasks)
                task_ref, img_batch, query_batch = self.form_grouped_batch(zm, tasks)
                if len(task_ref) == 0 or num_g >= max_corrs:
                    break
                out = self.infer_batch_grouped(img_batch, query_batch)
                num_steps = 0
                for i, squad in enumerate(task_ref):
                    for j, member in enumerate(squad):
                        member.step(out[i, j])
                        num_steps += 1
                print(f'solved {num_steps} sub-tasks in one invocation with {img_batch.shape[0]} image pairs')
                if num_steps <= self.batch_size:
                    break
            # Rollback to default inference, because of too few valid tasks can be grouped together.
            while True:
                num_g = self.num_good_tasks(tasks)
                print(f'{num_g} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}')
                task_ref, img_batch, query_batch = self.form_batch(tasks, zm)
                if len(task_ref) == 0 or num_g >= max_corrs:
                    break
                out = self.infer_batch(img_batch, query_batch)
                for t, o in zip(task_ref, out):
                    t.step(o)

        if return_tasks_only:
            return tasks
        if return_idx:
            corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force,
                                             img_a_shape=img_a_shape,
                                             img_b_shape=img_b_shape,)
            return corrs[:max_corrs], idx[:max_corrs]
        corrs = self.conclude_tasks(tasks, force=force,
                                    img_a_shape=img_a_shape,
                                    img_b_shape=img_b_shape,)
        return corrs[:max_corrs]
third_party/COTR/COTR/models/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ The COTR model is modified from DETR code base.
3
+ https://github.com/facebookresearch/detr
4
+ '''
5
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
6
+ from .cotr_model import build
7
+
8
+
9
def build_model(args):
    """Build the model from the parsed options by delegating to ``cotr_model.build``."""
    model = build(args)
    return model
third_party/COTR/COTR/models/backbone.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Backbone modules.
4
+ """
5
+ from collections import OrderedDict
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision
10
+ from torch import nn
11
+ from torchvision.models._utils import IntermediateLayerGetter
12
+ from typing import Dict, List
13
+
14
+ from .misc import NestedTensor
15
+
16
+ from .position_encoding import build_position_encoding
17
+ from COTR.utils import debug_utils, constants
18
+
19
+
20
class FrozenBatchNorm2d(torch.nn.Module):
    """
    BatchNorm2d whose statistics and affine terms are constant buffers.

    Adapted from torchvision's frozen batch norm, with an eps added before
    the rsqrt; without it, models other than
    torchvision.models.resnet[18,34,50,101] produce nans.
    """

    def __init__(self, n):
        super().__init__()
        # Buffers, not parameters: they are saved/loaded but never trained.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        # Regular BatchNorm checkpoints carry this counter; this module has
        # no such buffer, so drop it before the default loader sees it.
        state_dict.pop(prefix + 'num_batches_tracked', None)
        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict,
            missing_keys, unexpected_keys, error_msgs)

    def forward(self, x):
        # Reshape first so the arithmetic below stays fuser-friendly.
        channel_shape = (1, -1, 1, 1)
        eps = 1e-5
        weight = self.weight.reshape(channel_shape)
        offset = self.bias.reshape(channel_shape)
        var = self.running_var.reshape(channel_shape)
        mean = self.running_mean.reshape(channel_shape)
        scale = weight * (var + eps).rsqrt()
        shift = offset - mean * scale
        return x * scale + shift
57
+
58
+
59
class BackboneBase(nn.Module):
    """Wrap a backbone and run it separately on the two halves of a side-by-side input."""

    def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool, layer='layer3'):
        super().__init__()
        trainable_stages = ('layer2', 'layer3', 'layer4')
        for name, parameter in backbone.named_parameters():
            in_trainable_stage = any(stage in name for stage in trainable_stages)
            # Freeze everything when not training; otherwise freeze only the
            # parameters outside layer2/3/4.
            if not train_backbone or not in_trainable_stage:
                parameter.requires_grad_(False)
                # print(f'freeze {name}')
        if return_interm_layers:
            return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
        else:
            return_layers = {layer: "0"}
        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
        self.num_channels = num_channels

    def forward_raw(self, x):
        """Run the body on a plain tensor and return its single output feature map."""
        features = self.body(x)
        assert len(features.keys()) == 1
        return features['0']

    def forward(self, tensor_list: NestedTensor):
        """
        Extract features for the left and right halves and join them along width.

        The input tensors must have spatial size (MAX_SIZE, 2 * MAX_SIZE).
        Returns a dict mapping each returned layer name to a NestedTensor
        whose mask is the input mask resized to the feature resolution.
        """
        side = constants.MAX_SIZE
        images = tensor_list.tensors
        assert images.shape[-2:] == (side, side * 2)
        left = self.body(images[..., 0:side])
        right = self.body(images[..., side:2 * side])

        out: Dict[str, NestedTensor] = {}
        for key in left.keys():
            joined = torch.cat([left[key], right[key]], dim=-1)
            m = tensor_list.mask
            assert m is not None
            mask = F.interpolate(m[None].float(), size=joined.shape[-2:]).to(torch.bool)[0]
            out[key] = NestedTensor(joined, mask)
        return out
93
+
94
+
95
class Backbone(BackboneBase):
    """ResNet backbone with frozen BatchNorm."""

    def __init__(self, name: str,
                 train_backbone: bool,
                 return_interm_layers: bool,
                 dilation: bool,
                 layer='layer3',
                 num_channels=1024):
        # `name` selects a constructor in torchvision.models; the network is
        # created with pretrained weights and FrozenBatchNorm2d norm layers.
        factory = getattr(torchvision.models, name)
        backbone = factory(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=True,
            norm_layer=FrozenBatchNorm2d,
        )
        super().__init__(backbone, train_backbone, num_channels, return_interm_layers, layer)
108
+
109
+
110
class Joiner(nn.Sequential):
    """Pair a backbone (index 0) with a position encoder (index 1)."""

    def __init__(self, backbone, position_embedding):
        super().__init__(backbone, position_embedding)

    def forward(self, tensor_list: NestedTensor):
        """Return (features, pos): the backbone outputs and their position encodings, in dict order."""
        backbone, position_embedding = self[0], self[1]
        out: List[NestedTensor] = list(backbone(tensor_list).values())
        # Cast each encoding to the dtype of the features it accompanies.
        pos = [position_embedding(x).to(x.tensors.dtype) for x in out]
        return out, pos
124
+
125
+
126
def build_backbone(args):
    """
    Assemble the backbone + position-encoding Joiner from the options.

    The backbone is trainable only when ``args.lr_backbone`` exists and is
    positive. Note that ``args.dim_feedforward`` is what gets passed as the
    backbone's ``num_channels``.
    """
    position_embedding = build_position_encoding(args)
    train_backbone = hasattr(args, 'lr_backbone') and args.lr_backbone > 0
    backbone = Backbone(args.backbone, train_backbone, False, args.dilation,
                        layer=args.layer, num_channels=args.dim_feedforward)
    joiner = Joiner(backbone, position_embedding)
    joiner.num_channels = backbone.num_channels
    return joiner
third_party/COTR/COTR/models/cotr_model.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from COTR.utils import debug_utils, constants, utils
9
+ from .misc import (NestedTensor, nested_tensor_from_tensor_list)
10
+ from .backbone import build_backbone
11
+ from .transformer import build_transformer
12
+ from .position_encoding import NerfPositionalEncoding, MLP
13
+
14
+
15
class COTR(nn.Module):
    """Backbone + transformer model that regresses a 2-D correspondence per query point."""

    def __init__(self, backbone, transformer, sine_type='lin_sine'):
        super().__init__()
        # Submodules are created in a fixed order so that weight
        # initialisation and state-dict layout stay unchanged.
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.corr_embed = MLP(hidden_dim, hidden_dim, 2, 3)
        self.query_proj = NerfPositionalEncoding(hidden_dim // 4, sine_type)
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone

    def forward(self, samples: NestedTensor, queries):
        """
        Predict correspondences for `queries` (shape: batch x num_queries x 2).

        `samples` may be a NestedTensor, or a list / tensor of images that is
        converted to one. Returns {'pred_corrs': ...} holding the last
        decoder layer's output.
        """
        if isinstance(samples, (list, torch.Tensor)):
            samples = nested_tensor_from_tensor_list(samples)
        features, pos = self.backbone(samples)

        src, mask = features[-1].decompose()
        assert mask is not None

        batch, num_queries, _ = queries.shape
        flat_queries = queries.reshape(-1, 2)
        encoded = self.query_proj(flat_queries).reshape(batch, num_queries, -1)
        # The transformer expects queries as (num_queries, batch, dim).
        encoded = encoded.permute(1, 0, 2)

        hs = self.transformer(self.input_proj(src), mask, encoded, pos[-1])[0]
        outputs_corr = self.corr_embed(hs)
        return {'pred_corrs': outputs_corr[-1]}
41
+
42
+
43
def build(args):
    """Create the backbone first, then the transformer, and wrap both in a COTR model."""
    backbone = build_backbone(args)
    transformer = build_transformer(args)
    return COTR(backbone, transformer, sine_type=args.position_embedding)
third_party/COTR/COTR/models/misc.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Misc functions, including distributed helpers.
4
+
5
+ Mostly copy-paste from torchvision references.
6
+ """
7
+ import os
8
+ import subprocess
9
+ import time
10
+ from collections import defaultdict, deque
11
+ import datetime
12
+ import pickle
13
+ from typing import Optional, List
14
+
15
+ import torch
16
+ import torch.distributed as dist
17
+ from torch import Tensor
18
+
19
+ # needed due to empty tensor bug in pytorch and torchvision 0.5
20
+ import torchvision
21
+ if float(torchvision.__version__.split('.')[1]) < 7:
22
+ from torchvision.ops import _new_empty_tensor
23
+ from torchvision.ops.misc import _output_size
24
+
25
+
26
+ def _max_by_axis(the_list):
27
+ # type: (List[List[int]]) -> List[int]
28
+ maxes = the_list[0]
29
+ for sublist in the_list[1:]:
30
+ for index, item in enumerate(sublist):
31
+ maxes[index] = max(maxes[index], item)
32
+ return maxes
33
+
34
+
35
class NestedTensor(object):
    """A tensor bundled with an optional mask tensor."""

    def __init__(self, tensors, mask: Optional[Tensor]):
        self.tensors = tensors
        self.mask = mask

    def to(self, device):
        # type: (Device) -> NestedTensor # noqa
        """Return a new NestedTensor with the tensor (and mask, if any) moved to `device`."""
        moved_mask = None if self.mask is None else self.mask.to(device)
        return NestedTensor(self.tensors.to(device), moved_mask)

    def decompose(self):
        """Return the (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
56
+
57
+
58
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    """
    Zero-pad a sequence of 3-D image tensors to a common size and batch them.

    Returns a NestedTensor whose mask is True on padded pixels and False on
    real image content. Raises ValueError('not supported') when the first
    element is not 3-dimensional.
    """
    # TODO make this more general
    if tensor_list[0].ndim != 3:
        raise ValueError('not supported')

    if torchvision._is_tracing():
        # nested_tensor_from_tensor_list() does not export well to ONNX
        # call _onnx_nested_tensor_from_tensor_list() instead
        return _onnx_nested_tensor_from_tensor_list(tensor_list)

    # TODO make it support different-sized images
    max_size = _max_by_axis([list(img.shape) for img in tensor_list])
    batch_shape = [len(tensor_list)] + max_size
    b, c, h, w = batch_shape
    first = tensor_list[0]
    tensor = torch.zeros(batch_shape, dtype=first.dtype, device=first.device)
    mask = torch.ones((b, h, w), dtype=torch.bool, device=first.device)
    for img, padded, pad_mask in zip(tensor_list, tensor, mask):
        channels, height, width = img.shape[0], img.shape[1], img.shape[2]
        # Copy into the top-left corner and unmask exactly that region.
        padded[:channels, :height, :width].copy_(img)
        pad_mask[:height, :width] = False
    return NestedTensor(tensor, mask)
81
+
82
+
83
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
84
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
85
@torch.jit.unused
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
    """
    ONNX-traceable counterpart of nested_tensor_from_tensor_list().

    Builds the padded batch and mask with torch.nn.functional.pad instead of
    in-place slice assignment, which is not yet supported in onnx.
    """
    max_size = tuple(
        torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
        for i in range(tensor_list[0].dim())
    )

    padded_imgs = []
    padded_masks = []
    for img in tensor_list:
        padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
        # F.pad takes (left, right) pairs starting from the last dimension.
        padded_imgs.append(
            torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])))

        # Zero marks real pixels; the padded border is filled with 1.
        valid = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
        border = torch.nn.functional.pad(valid, (0, padding[2], 0, padding[1]), "constant", 1)
        padded_masks.append(border.to(torch.bool))

    tensor = torch.stack(padded_imgs)
    mask = torch.stack(padded_masks)

    return NestedTensor(tensor, mask=mask)
112
+
third_party/COTR/COTR/models/position_encoding.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ Various positional encodings for the transformer.
4
+ """
5
+ import math
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+
10
+ from .misc import NestedTensor
11
+ from COTR.utils import debug_utils
12
+
13
+
14
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_sizes = [hidden_dim] * (num_layers - 1)
        in_sizes = [input_dim] + hidden_sizes
        out_sizes = hidden_sizes + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(in_sizes, out_sizes))

    def forward(self, x):
        # ReLU follows every layer except the final one.
        last = self.num_layers - 1
        for idx, layer in enumerate(self.layers):
            x = layer(x)
            if idx < last:
                x = F.relu(x)
        return x
27
+
28
+
29
class NerfPositionalEncoding(nn.Module):
    """Sine/cosine encoding of coordinates at several frequencies."""

    def __init__(self, depth=10, sine_type='lin_sine'):
        '''
        out_dim = in_dim * depth * 2

        sine_type selects the frequency multipliers: 'lin_sine' uses
        1, 2, ..., depth and 'exp_sine' uses 1, 2, 4, ..., 2**(depth-1).
        Any other value raises ValueError (previously it was accepted here
        and only failed later, inside forward, with an AttributeError).
        '''
        super().__init__()
        if sine_type == 'lin_sine':
            self.bases = [i+1 for i in range(depth)]
        elif sine_type == 'exp_sine':
            self.bases = [2**i for i in range(depth)]
        else:
            raise ValueError(f'not supported {sine_type}')
        print(f'using {sine_type} as positional encoding')

    @torch.no_grad()
    def forward(self, inputs):
        """Concatenate sin(b*pi*x) for every base b, then cos(b*pi*x), along the last dim."""
        sines = [torch.sin(base * math.pi * inputs) for base in self.bases]
        cosines = [torch.cos(base * math.pi * inputs) for base in self.bases]
        out = torch.cat(sines + cosines, axis=-1)
        assert torch.isnan(out).any() == False
        return out
46
+
47
+
48
class PositionEmbeddingSine(nn.Module):
    """
    Image position embedding built from normalised pixel coordinates.

    Coordinates are derived from the unmasked region and passed through a
    NerfPositionalEncoding with num_pos_feats // 2 frequencies.
    temperature, normalize and scale are accepted but do not affect the
    computation here.
    """
    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None, sine_type='lin_sine'):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.sine = NerfPositionalEncoding(num_pos_feats//2, sine_type)

    @torch.no_grad()
    def forward(self, tensor_list: NestedTensor):
        """Return the encoding with the feature dimension moved to position 1."""
        mask = tensor_list.mask
        assert mask is not None
        valid = ~mask
        eps = 1e-6
        # Running counts of valid pixels act as 1-based row/column indices.
        row_index = valid.cumsum(1, dtype=torch.float32)
        col_index = valid.cumsum(2, dtype=torch.float32)
        # Shift by half a pixel and divide by the total count along the axis.
        y_embed = (row_index-0.5) / (row_index[:, -1:, :] + eps)
        x_embed = (col_index-0.5) / (col_index[:, :, -1:] + eps)
        pos = torch.stack([x_embed, y_embed], dim=-1)
        return self.sine(pos).permute(0, 3, 1, 2)
73
+
74
+
75
def build_position_encoding(args):
    """
    Create the image position encoder selected by ``args.position_embedding``.

    Only 'lin_sine' and 'exp_sine' are supported; anything else raises
    ValueError. The encoder is built with ``args.hidden_dim // 2`` features.
    """
    N_steps = args.hidden_dim // 2
    if args.position_embedding not in ('lin_sine', 'exp_sine'):
        raise ValueError(f"not supported {args.position_embedding}")
    # TODO find a better way of exposing other arguments
    return PositionEmbeddingSine(N_steps, normalize=True, sine_type=args.position_embedding)
third_party/COTR/COTR/models/transformer.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2
+ """
3
+ COTR/DETR Transformer class.
4
+
5
+ Copy-paste from torch.nn.Transformer with modifications:
6
+ * positional encodings are passed in MHattention
7
+ * extra LN at the end of encoder is removed
8
+ * decoder returns a stack of activations from all decoding layers
9
+ """
10
+ import copy
11
+ from typing import Optional, List
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from torch import nn, Tensor
16
+
17
+ from COTR.utils import debug_utils
18
+
19
+
20
class Transformer(nn.Module):
    """Encoder-decoder transformer operating on flattened image features."""

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", return_intermediate_dec=False):
        super().__init__()

        self.encoder = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                    dropout, activation),
            num_encoder_layers)

        self.decoder = TransformerDecoder(
            TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                    dropout, activation),
            num_decoder_layers,
            nn.LayerNorm(d_model),
            return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every parameter that has more than one dimension."""
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def forward(self, src, mask, query_embed, pos_embed):
        """
        Encode `src` and decode it with `query_embed` as query positions.

        Returns (hs, memory): the decoder output with dims 1 and 2 swapped,
        and the encoder memory reshaped back to the shape of `src`.
        """
        bs, c, h, w = src.shape
        # flatten NxCxHxW to HWxNxC
        flat_src = src.flatten(2).permute(2, 0, 1)
        flat_pos = pos_embed.flatten(2).permute(2, 0, 1)
        flat_mask = mask.flatten(1)

        # The decoder starts from zeros; queries enter only as positions.
        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(flat_src, src_key_padding_mask=flat_mask, pos=flat_pos)
        hs = self.decoder(tgt, memory, memory_key_padding_mask=flat_mask,
                          pos=flat_pos, query_pos=query_embed)
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
59
+
60
+
61
class TransformerEncoder(nn.Module):
    """A stack of independent copies of one encoder layer."""

    def __init__(self, encoder_layer, num_layers):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        """Feed `src` through every layer in turn and return the final output."""
        hidden = src
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask=mask,
                                   src_key_padding_mask=src_key_padding_mask, pos=pos)
        return hidden
79
+
80
+
81
class TransformerDecoder(nn.Module):
    """A stack of independent copies of one decoder layer, with an optional final norm."""

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        """
        Run `tgt` through every decoder layer.

        Returns the stacked outputs of all layers when return_intermediate
        is set, otherwise the final output with a leading dimension of size
        1. Outputs are normalised only when a norm module was supplied;
        return_intermediate without a norm used to fail by calling None.
        """
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                # Without a norm the raw layer output is recorded.
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                # Replace the last entry with the normalised final output.
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)
120
+
121
+
122
class TransformerEncoderLayer(nn.Module):
    """Self-attention followed by a feed-forward block, each with residual + LayerNorm."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional term when one is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self,
                src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # Positions are added to queries and keys only; values stay raw.
        with_pos = self.with_pos_embed(src, pos)
        attended = self.self_attn(query=with_pos,
                                  key=with_pos,
                                  value=src,
                                  attn_mask=src_mask,
                                  key_padding_mask=src_key_padding_mask)[0]
        src = self.norm1(src + self.dropout1(attended))

        hidden = self.dropout(self.activation(self.linear1(src)))
        feed_forward = self.linear2(hidden)
        return self.norm2(src + self.dropout2(feed_forward))
160
+
161
+
162
class TransformerDecoderLayer(nn.Module):
    """Cross-attention over the memory followed by a feed-forward block."""

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu"):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        # norm1 / dropout1 are not used by forward() but are still created,
        # so the module's parameter names stay as they were.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        """Add the positional term when one is given."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        # tgt_mask and tgt_key_padding_mask are accepted but not used:
        # there is no self-attention step in this layer.
        attended = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                       key=self.with_pos_embed(memory, pos),
                                       value=memory, attn_mask=memory_mask,
                                       key_padding_mask=memory_key_padding_mask)[0]
        tgt = self.norm2(tgt + self.dropout2(attended))

        hidden = self.dropout(self.activation(self.linear1(tgt)))
        feed_forward = self.linear2(hidden)
        return self.norm3(tgt + self.dropout3(feed_forward))
202
+
203
+
204
+ def _get_clones(module, N):
205
+ return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
206
+
207
+
208
def build_transformer(args):
    """Create the Transformer from the options; intermediate decoder outputs are always returned."""
    settings = dict(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        return_intermediate_dec=True,
    )
    return Transformer(**settings)
218
+
219
+
220
+ def _get_activation_fn(activation):
221
+ """Return an activation function given a string"""
222
+ if activation == "relu":
223
+ return F.relu
224
+ if activation == "gelu":
225
+ return F.gelu
226
+ if activation == "glu":
227
+ return F.glu
228
+ raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
third_party/COTR/COTR/options/options.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import argparse
3
+ import json
4
+ import os
5
+
6
+
7
+ from COTR.options.options_utils import str2bool
8
+ from COTR.options import options_utils
9
+ from COTR.global_configs import general_config, dataset_config
10
+ from COTR.utils import debug_utils
11
+
12
+
13
def set_general_arguments(parser):
    '''Register the "General" option group (all boolean flags parsed by str2bool).
    '''
    group = parser.add_argument_group('General')
    # (flag, default, help) -- every option in this group is a str2bool flag
    flag_specs = (
        ('--confirm', True, 'promote confirmation for user'),
        ('--use_cuda', True, 'use cuda'),
        ('--use_cc', False, 'use computecanada'),
    )
    for flag, default, help_text in flag_specs:
        group.add_argument(flag, type=str2bool, default=default, help=help_text)
21
+
22
+
23
def set_dataset_arguments(parser):
    '''Register the "Data" option group (dataset selection and loading options).
    '''
    group = parser.add_argument_group('Data')
    option_specs = (
        ('--dataset_name', dict(type=str, default='megadepth', help='dataset name')),
        ('--shuffle_data', dict(type=str2bool, default=True, help='use sequence dataset or shuffled dataset')),
        ('--use_ram', dict(type=str2bool, default=False, help='load image/depth/pcd to ram')),
        ('--info_level', dict(choices=['rgb', 'rgbd'], type=str, default='rgbd', help='the information level of dataset')),
        ('--scene_file', dict(type=str, default=None, required=False, help='what scene/seq want to use')),
        ('--workers', dict(type=int, default=0, help='worker for loading data')),
        ('--crop_cam', dict(choices=['no_crop', 'crop_center', 'crop_center_and_resize'], type=str, default='crop_center_and_resize', help='crop the center of image to avoid changing aspect ratio, resize to make the operations batch-able.')),
    )
    for flag, kwargs in option_specs:
        group.add_argument(flag, **kwargs)
32
+
33
+
34
def set_nn_arguments(parser):
    '''Register the "Nearest neighbors" option group.
    '''
    group = parser.add_argument_group('Nearest neighbors')
    option_specs = (
        ('--nn_method', dict(choices=['netvlad', 'overlapping'], type=str, default='overlapping', help='how to select nearest neighbors')),
        ('--pool_size', dict(type=int, default=20, help='a pool of sorted nn candidates')),
        ('--k_size', dict(type=int, default=1, help='select the nn randomly from pool')),
    )
    for flag, kwargs in option_specs:
        group.add_argument(flag, **kwargs)
39
+
40
+
41
def set_COTR_arguments(parser):
    '''Register the "COTR model" option group (backbone and transformer sizes).
    '''
    group = parser.add_argument_group('COTR model')
    option_specs = (
        ('--backbone', dict(type=str, default='resnet50')),
        ('--hidden_dim', dict(type=int, default=256)),
        ('--dilation', dict(type=str2bool, default=False)),
        ('--dropout', dict(type=float, default=0.1)),
        ('--nheads', dict(type=int, default=8)),
        ('--layer', dict(type=str, default='layer3', help='which layer from resnet')),
        ('--enc_layers', dict(type=int, default=6)),
        ('--dec_layers', dict(type=int, default=6)),
        ('--position_embedding', dict(type=str, default='lin_sine', help='sine wave type')),
    )
    for flag, kwargs in option_specs:
        group.add_argument(flag, **kwargs)
52
+
third_party/COTR/COTR/options/options_utils.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''utils for argparse
2
+ '''
3
+
4
+ import sys
5
+ import os
6
+ from os import path
7
+ import time
8
+ import json
9
+
10
+ from COTR.utils import utils, debug_utils
11
+ from COTR.global_configs import general_config, dataset_config
12
+
13
+
14
def str2bool(v: str) -> bool:
    '''Interpret a command-line string as a boolean.

    Case-insensitive; 'true', '1', 'yes', 'y' and 't' map to True and
    every other string maps to False.
    '''
    truthy_spellings = ('true', '1', 'yes', 'y', 't')
    normalized = v.lower()
    return normalized in truthy_spellings
16
+
17
+
18
def get_compact_naming_cotr(opt) -> str:
    '''Build a compact run name from the key COTR options.

    A '_suffix:<suffix>' tail is appended only when ``opt.suffix`` is truthy.
    '''
    fields = (
        opt.backbone,
        opt.layer,
        opt.dim_feedforward,
        opt.dataset_name,
        opt.batch_size,
        opt.position_embedding,
        opt.lr_backbone,
    )
    name = 'model:cotr_{0}_{1}_{2}_dset:{3}_bs:{4}_pe:{5}_lrbackbone:{6}'.format(*fields)
    if not opt.suffix:
        return name
    return name + '_suffix:{0}'.format(opt.suffix)
31
+
32
+
33
def print_opt(opt):
    '''Print every option of ``opt`` (sorted by name) through utils.print_notification.
    '''
    # one right-aligned "name value" line per option
    lines = [
        name.rjust(25, ' ') + ' ' + str(getattr(opt, name))
        for name in sorted(vars(opt))
    ]
    utils.print_notification(lines, 'OPTIONS')
40
+
41
+
42
def confirm_opt(opt):
    '''Print the options and, unless ``opt.use_cc`` is set, ask the user to confirm.

    The process exits with status 1 when the user declines.
    '''
    print_opt(opt)
    needs_confirmation = opt.use_cc == False
    # short-circuit: utils.confirm() is only called when confirmation is needed
    if needs_confirmation and not utils.confirm():
        exit(1)
47
+
48
+
49
def opt_to_string(opt) -> str:
    '''Render the launch command and all options as one text block.

    The result starts with the reconstructed command line (``sys.argv``),
    followed by one right-aligned "name value" entry per option, sorted by
    name; entries are separated by blank lines.
    '''
    pieces = ['\n\n', 'python ' + ' '.join(sys.argv), '\n\n']
    for name in sorted(vars(opt)):
        pieces.append(name.rjust(25, ' ') + ' ' + str(getattr(opt, name)) + '\n\n')
    return ''.join(pieces)
60
+
61
+
62
def save_opt(opt):
    '''save options to a json file

    Writes ``vars(opt)`` to ``<opt.out>/params.json``, creating ``opt.out``
    if needed. When a params file already exists and the run is not a debug
    run ('debug' not in ``opt.suffix``), ``opt.resume`` must be set; every
    option that differs from the stored file is printed and, unless
    ``opt.use_cc`` is set, the user is asked to confirm before overwriting.
    '''
    if not os.path.exists(opt.out):
        os.makedirs(opt.out)
    json_path = os.path.join(opt.out, 'params.json')
    if 'debug' not in opt.suffix and os.path.isfile(json_path):
        assert opt.resume, 'You are trying to modify a model without resuming: {0}'.format(opt.out)
        # context manager so the handle is closed deterministically
        with open(json_path) as fp:
            old_dict = json.load(fp)
        new_dict = vars(opt)
        if new_dict != old_dict:
            exception_keys = ['command']
            # sorted for a stable report order
            for key in sorted(set(old_dict.keys()).union(set(new_dict.keys()))):
                if key in exception_keys:
                    continue
                old_val = old_dict[key] if key in old_dict else 'not exists(old)'
                # look the key up in new_dict (the old code tested old_dict here,
                # which raised KeyError for keys present only in the stored file)
                new_val = new_dict[key] if key in new_dict else 'not exists(new)'
                if old_val != new_val:
                    print('key: {0}, old_val: {1}, new_val: {2}'.format(key, old_val, new_val))
            if opt.use_cc == False:
                if not utils.confirm('Please manually confirm'):
                    exit(1)
    with open(json_path, 'w') as fp:
        json.dump(vars(opt), fp, indent=0, sort_keys=True)
86
+
87
+
88
def build_scenes_name_list_from_opt(opt):
    '''Expand the scene selection in ``opt`` into per-scene directory dicts.

    Scenes come from the JSON file ``opt.scene_file`` when it is set, else
    from the single ``opt.scene`` / ``opt.seq`` pair. Each result maps
    'scene_dir', 'image_dir' (and 'depth_dir' for info level 'rgbd') to the
    dataset_config path templates formatted with the scene and sequence.

    Raises NotImplementedError for datasets other than megadepth.
    '''
    scene_file = getattr(opt, 'scene_file', None)
    if scene_file is None:
        scenes_list = [{'scene': opt.scene, 'seq': opt.seq}]
    else:
        assert os.path.isfile(scene_file), scene_file
        with open(scene_file, 'r') as f:
            scenes_list = json.load(f)

    if 'megadepth' not in opt.dataset_name:
        raise NotImplementedError()

    assert opt.info_level in ['rgb', 'rgbd']
    dir_keys = ['scene_dir', 'image_dir']
    if opt.info_level == 'rgbd':
        dir_keys.append('depth_dir')
    templates = {key: dataset_config[opt.dataset_name][key] for key in dir_keys}
    return [
        {key: template.format(item['scene'], item['seq']) for key, template in templates.items()}
        for item in scenes_list
    ]
third_party/COTR/COTR/projector/pcd_projector.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ a point cloud projector based on np
3
+ '''
4
+
5
+ import numpy as np
6
+
7
+ from COTR.utils import debug_utils, utils
8
+
9
+
10
def render_point_cloud_at_capture(point_cloud, capture, render_type='rgb', return_pcd=False):
    '''Project ``point_cloud`` into ``capture`` and rasterize it to an image.

    ``render_type`` is one of 'rgb' (point cloud must have 6 columns),
    'bw' or 'depth' (only the first 3 columns are used). For 'depth' the
    z value is kept as the image channel. When ``return_pcd`` is True the
    projected 2D point cloud is returned alongside the image.
    '''
    assert render_type in ['rgb', 'bw', 'depth']
    if render_type != 'rgb':
        point_cloud = point_cloud[:, :3]
        assert point_cloud.shape[1] == 3
    else:
        assert point_cloud.shape[1] == 6
    # only the depth rendering keeps z as an output channel
    keep_depth_channel = render_type == 'depth'

    projected = PointCloudProjector.pcd_3d_to_pcd_2d_np(
        point_cloud,
        capture.intrinsic_mat,
        capture.extrinsic_mat,
        capture.size,
        keep_z=True,
        crop=True,
        filter_neg=True,
        norm_coord=False,
        return_index=False,
    )
    rendered = PointCloudProjector.pcd_2d_to_img_2d_np(
        projected,
        capture.size,
        has_z=True,
        keep_z=keep_depth_channel,
    )
    return (rendered, projected) if return_pcd else rendered
39
+
40
+
41
def optical_flow_from_a_to_b(cap_a, cap_b):
    '''Rasterize cap_b's pixel coordinates into cap_a's image plane.

    A pixel-coordinate map of cap_b's size is passed to
    ``cap_b.get_point_cloud_world_from_depth`` (presumably attached as
    per-point features -- confirm against that method), the resulting world
    points are projected with cap_a's intrinsics and world-to-camera matrix,
    and the projection is drawn onto a canvas of cap_a's size.
    '''
    intrinsic_a = cap_a.pinhole_cam.intrinsic_mat
    size_a = cap_a.pinhole_cam.shape[:2]
    height_b, width_b = cap_b.pinhole_cam.shape[:2]
    grid_x, grid_y = np.meshgrid(
        np.linspace(0, width_b - 1, num=width_b),
        np.linspace(0, height_b - 1, num=height_b),
    )
    # (H, W, 2) map holding the (x, y) pixel coordinate of every cap_b pixel
    coord_map = np.stack([grid_x, grid_y], axis=2)
    world_points = cap_b.get_point_cloud_world_from_depth(coord_map)
    world_to_camera_a = cap_a.cam_pose.world_to_camera[0:3, :]
    projected = PointCloudProjector.pcd_3d_to_pcd_2d_np(
        world_points,
        intrinsic_a,
        world_to_camera_a,
        size_a,
        keep_z=True,
        crop=True,
        filter_neg=True,
        norm_coord=False,
    )
    return PointCloudProjector.pcd_2d_to_img_2d_np(projected, size_a, has_z=True, keep_z=False)
54
+
55
+
56
class PointCloudProjector():
    '''Stateless NumPy helpers converting between images, 2D point clouds
    (pixel coordinates + features) and 3D point clouds.
    '''

    def __init__(self):
        pass

    @staticmethod
    def pcd_2d_to_pcd_3d_np(pcd, depth, intrinsic, motion=None, return_index=False):
        '''Back-project 2D points with depth into 3D.

        pcd: [num_points, 2 + F] array, columns 0/1 are x/y, the rest are features.
        depth: [num_points, 1] array of depths.
        intrinsic: 3x3 camera matrix.
        motion: optional 4x4 transform applied to the back-projected points.
        return_index: also return the indices (into ``pcd``) of the kept points.

        Points with non-positive z after back-projection are dropped, as are
        points whose homogeneous w is 0 after ``motion``. Features are carried
        through and appended after the xyz columns.
        '''
        assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd))
        assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
        assert len(pcd.shape) == 2 and pcd.shape[1] >= 2
        assert len(depth.shape) == 2 and depth.shape[1] == 1
        assert intrinsic.shape == (3, 3)
        if motion is not None:
            assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion))
            assert motion.shape == (4, 4)

        x, y, z = pcd[:, 0], pcd[:, 1], depth[:, 0]
        append_ones = np.ones_like(x)
        xyz = np.stack([x, y, append_ones], axis=1)  # shape: [num_points, 3]
        inv_intrinsic_mat = np.linalg.inv(intrinsic)
        xyz = np.matmul(inv_intrinsic_mat, xyz.T).T * z[..., None]

        # Track surviving indices/features through every filter so they stay
        # aligned with xyz whether or not ``motion`` is given. (Previously
        # return_index=True without motion raised UnboundLocalError.)
        points_index = np.arange(pcd.shape[0])
        features = pcd[:, 2:] if pcd.shape[1] > 2 else None

        valid_mask_1 = np.where(xyz[:, 2] > 0)
        xyz = xyz[valid_mask_1]
        points_index = points_index[valid_mask_1]
        if features is not None:
            features = features[valid_mask_1]

        if motion is not None:
            append_ones = np.ones_like(xyz[:, 0:1])
            xyzw = np.concatenate([xyz, append_ones], axis=1)
            xyzw = np.matmul(motion, xyzw.T).T
            valid_mask_2 = np.where(xyzw[:, 3] != 0)
            xyzw = xyzw[valid_mask_2]
            xyzw /= xyzw[:, 3:4]
            xyz = xyzw[:, 0:3]
            points_index = points_index[valid_mask_2]
            if features is not None:
                features = features[valid_mask_2]

        if features is not None:
            assert xyz.shape[0] == features.shape[0]
            xyz = np.concatenate([xyz, features], axis=1)
        if return_index:
            return xyz, points_index
        return xyz

    @staticmethod
    def img_2d_to_pcd_3d_np(depth, intrinsic, img=None, motion=None):
        '''
        the function signature is not fully correct, because img is an optional
        if motion is None, the output pcd is in camera space
        if motion is camera_to_world, the out pcd is in world space.
        here the output is pure np array

        depth is a 2D depth map; pixels with depth <= 0 are dropped.
        img, when given, must be [H, W, C] with the same resolution as depth;
        its channels are appended to each point as features.
        '''

        assert isinstance(depth, np.ndarray), 'cannot process data type: {0}'.format(type(depth))
        assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
        assert len(depth.shape) == 2
        assert intrinsic.shape == (3, 3)
        if img is not None:
            assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img))
            assert len(img.shape) == 3
            assert img.shape[:2] == depth.shape[:2], 'feature should have the same resolution as the depth'
        if motion is not None:
            assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion))
            assert motion.shape == (4, 4)

        # one (x, y, depth) row per pixel, in unnormalized pixel coordinates
        pcd_image_space = PointCloudProjector.img_2d_to_pcd_2d_np(depth[..., None], norm_coord=False)
        valid_mask_1 = np.where(pcd_image_space[:, 2] > 0)
        pcd_image_space = pcd_image_space[valid_mask_1]
        xy = pcd_image_space[:, :2]
        z = pcd_image_space[:, 2:3]
        if img is not None:
            _c = img.shape[-1]
            feat = img.reshape(-1, _c)
            feat = feat[valid_mask_1]
            xy = np.concatenate([xy, feat], axis=1)
        pcd_3d = PointCloudProjector.pcd_2d_to_pcd_3d_np(xy, z, intrinsic, motion=motion)
        return pcd_3d

    @staticmethod
    def pcd_3d_to_pcd_2d_np(pcd, intrinsic, extrinsic, size, keep_z: bool, crop: bool = True, filter_neg: bool = True, norm_coord: bool = True, return_index: bool = False):
        '''Project a 3D point cloud into an image plane.

        pcd: [num_points, 3 + F] array (xyz followed by optional features).
        intrinsic / extrinsic: multiplied as ``intrinsic @ extrinsic`` and
            applied to homogeneous points.
        size: (height, width) of the target image, used for cropping and
            normalization.
        keep_z: insert the projected depth as a third column.
        crop: drop points outside [0, width - 1) x [0, height - 1).
        filter_neg: drop points with non-positive depth.
        norm_coord: rescale the pixel coordinates to ((p / size) * 2) - 1.
        return_index: also return the indices (into ``pcd``) of the kept points.
        '''
        assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd))
        assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
        assert isinstance(extrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(extrinsic))
        assert len(pcd.shape) == 2 and pcd.shape[1] >= 3, 'seems the input pcd is not a valid 3d point cloud: {0}'.format(pcd.shape)

        xyzw = np.concatenate([pcd[:, 0:3], np.ones_like(pcd[:, 0:1])], axis=1)
        mvp_mat = np.matmul(intrinsic, extrinsic)
        camera_points = np.matmul(mvp_mat, xyzw.T).T
        if filter_neg:
            valid_mask_1 = camera_points[:, 2] > 0.0
        else:
            valid_mask_1 = np.ones_like(camera_points[:, 2], dtype=bool)
        camera_points = camera_points[valid_mask_1]
        # perspective divide
        image_points = camera_points / camera_points[:, 2:3]
        image_points = image_points[:, :2]
        if crop:
            valid_mask_2 = (image_points[:, 0] >= 0) * (image_points[:, 0] < size[1] - 1) * (image_points[:, 1] >= 0) * (image_points[:, 1] < size[0] - 1)
        else:
            valid_mask_2 = np.ones_like(image_points[:, 0], dtype=bool)
        # the crop mask above is computed on pixel coordinates, before normalization
        if norm_coord:
            image_points = ((image_points / size[::-1]) * 2) - 1

        if keep_z:
            image_points = np.concatenate([image_points[valid_mask_2], camera_points[valid_mask_2][:, 2:3], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1)
        else:
            image_points = np.concatenate([image_points[valid_mask_2], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1)
        if return_index:
            points_index = np.arange(pcd.shape[0])[valid_mask_1][valid_mask_2]
            return image_points, points_index
        return image_points

    @staticmethod
    def pcd_2d_to_img_2d_np(pcd, size, has_z=False, keep_z=False):
        '''Rasterize a 2D point cloud onto an image canvas of ``size`` (h, w).

        pcd columns are x, y, [z if has_z], features. With ``has_z`` the
        points are written in order of decreasing z (presumably so nearer
        points overwrite farther ones at the same pixel); z is removed from
        the written channels unless ``keep_z``. Coordinates are rounded and
        clipped into the canvas. Without feature columns the canvas has one
        channel set to 1.0 at every hit pixel.
        '''
        assert len(pcd.shape) == 2 and pcd.shape[-1] >= 2, 'seems the input pcd is not a valid point cloud: {0}'.format(pcd.shape)
        if has_z:
            pcd = pcd[pcd[:, 2].argsort()[::-1]]
            if not keep_z:
                pcd = np.delete(pcd, [2], axis=1)
        index_list = np.round(pcd[:, 0:2]).astype(np.int32)
        index_list[:, 0] = np.clip(index_list[:, 0], 0, size[1] - 1)
        index_list[:, 1] = np.clip(index_list[:, 1], 0, size[0] - 1)
        _h, _w, _c = *size, pcd.shape[-1] - 2
        if _c == 0:
            canvas = np.zeros((_h, _w, 1))
            canvas[index_list[:, 1], index_list[:, 0]] = 1.0
        else:
            canvas = np.zeros((_h, _w, _c))
            canvas[index_list[:, 1], index_list[:, 0]] = pcd[:, 2:]

        return canvas

    @staticmethod
    def img_2d_to_pcd_2d_np(img, norm_coord=True):
        '''Flatten an [H, W, C] image into a [H * W, 2 + C] point cloud.

        Columns are x, y and the C image channels. Coordinates span [-1, 1]
        when ``norm_coord`` is True, else [0, W - 1] / [0, H - 1].
        '''
        assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img))
        assert len(img.shape) == 3

        _h, _w, _c = img.shape
        if norm_coord:
            x, y = np.meshgrid(
                np.linspace(-1, 1, num=_w),
                np.linspace(-1, 1, num=_h),
            )
        else:
            x, y = np.meshgrid(
                np.linspace(0, _w - 1, num=_w),
                np.linspace(0, _h - 1, num=_h),
            )
        x, y = x.reshape(-1, 1), y.reshape(-1, 1)
        feat = img.reshape(-1, _c)
        pcd_2d = np.concatenate([x, y, feat], axis=1)
        return pcd_2d
third_party/COTR/COTR/sfm_scenes/knn_search.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Given one capture in a scene, search for its KNN captures
3
+ '''
4
+
5
+ import os
6
+
7
+ import numpy as np
8
+
9
+ from COTR.utils import debug_utils
10
+ from COTR.utils.constants import VALID_NN_OVERLAPPING_THRESH
11
+
12
+
13
class ReprojRatioKnnSearch():
    '''Nearest-neighbor lookup over a scene's precomputed pairwise score matrix.

    The matrix is loaded from ``dist_mat/dist_mat.npy`` next to the depth
    folder of the scene's first capture. Larger entries rank first, and
    entries above VALID_NN_OVERLAPPING_THRESH count as valid neighbors.
    '''

    def __init__(self, scene):
        self.scene = scene
        self.distance_mat = None
        self.nn_index = None
        self._read_dist_mat()
        self._build_nn_index()

    def _read_dist_mat(self):
        '''Load the pairwise matrix from disk into ``self.distance_mat``.'''
        dist_mat_path = os.path.join(os.path.dirname(os.path.dirname(self.scene.captures[0].depth_path)), 'dist_mat/dist_mat.npy')
        self.distance_mat = np.load(dist_mat_path)

    def _build_nn_index(self):
        '''Precompute, for every row, the column indices sorted by descending score.'''
        # argsort is in ascending order, so we take negative
        self.nn_index = (-1 * self.distance_mat).argsort(axis=1)

    def get_knn(self, query, k, db_mask=None):
        '''Return up to ``k`` neighbor captures of ``query``.

        query: capture whose ``img_path`` is registered in the scene.
        k: number of neighbors requested.
        db_mask: optional array of capture indices the search is restricted to;
            all other captures get a score of -1 before ranking.

        When more than ``k`` valid neighbors exist, the top ``k`` (excluding
        the query itself) are returned. Otherwise the top ``max(num_valid, 1)``
        ranked captures are returned as-is (the query is not removed here).
        '''
        query_index = self.scene.img_path_to_index_dict[query.img_path]
        dist_row = self.distance_mat[query_index]
        if db_mask is None:
            num_pos = (dist_row > VALID_NN_OVERLAPPING_THRESH).sum()
            ranked = self.nn_index[query_index]
        else:
            num_pos = (dist_row[db_mask] > VALID_NN_OVERLAPPING_THRESH).sum()
            # push everything outside the database mask to the end of the ranking
            query_mask = np.setdiff1d(np.arange(dist_row.shape[0]), db_mask)
            temp_dist = dist_row.copy()
            temp_dist[query_mask] = -1
            ranked = (-1 * temp_dist).argsort(axis=0)

        # we have enough valid NN or not
        if num_pos > k:
            # take one extra candidate so that removing the query still leaves k
            ind = ranked[:k + 1]
            # remove self
            if query_index in ind:
                ind = np.delete(ind, np.argwhere(ind == query_index))
            else:
                ind = ind[:k]
            # both bounds are checked; the old form `assert a, b` used the
            # second condition as the assertion message only
            assert 0 < ind.shape[0] <= k, 'expected between 1 and {0} neighbors, got {1}'.format(k, ind.shape[0])
        else:
            ind = ranked[:max(num_pos, 1)]
        return self.scene.get_captures_given_index_list(ind)
third_party/COTR/COTR/sfm_scenes/sfm_scenes.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Scene reconstructed from SFM, mainly colmap
3
+ '''
4
+ import os
5
+ import copy
6
+ import math
7
+
8
+ import numpy as np
9
+ from numpy.linalg import inv
10
+ from tqdm import tqdm
11
+
12
+ from COTR.transformations import transformations
13
+ from COTR.transformations.transform_basics import Translation, Rotation
14
+ from COTR.cameras.camera_pose import CameraPose
15
+ from COTR.utils import debug_utils
16
+
17
+
18
class SfmScene():
    '''A reconstructed scene: a list of captures plus an optional point cloud.

    ``point_cloud`` may be a ``(points, point_meta)`` tuple, in which case
    the metadata is stored in ``self.point_meta``; otherwise ``point_meta``
    is None. Lookup tables from image path, file name and image id to the
    capture index are built on construction.
    '''

    def __init__(self, captures, point_cloud=None):
        self.captures = captures
        if isinstance(point_cloud, tuple):
            self.point_cloud = point_cloud[0]
            self.point_meta = point_cloud[1]
        else:
            self.point_cloud = point_cloud
            # always define the attribute so later access cannot raise AttributeError
            self.point_meta = None
        self.img_path_to_index_dict = {}
        self.img_id_to_index_dict = {}
        self.fname_to_index_dict = {}
        self._build_img_X_to_index_dict()

    def __str__(self):
        string = 'Scene contains {0} captures'.format(len(self.captures))
        return string

    def __getitem__(self, x):
        '''Index by position, or by image path / file name when ``x`` is a str.'''
        if isinstance(x, str):
            try:
                return self.captures[self.img_path_to_index_dict[x]]
            except KeyError:
                # not a full path: fall back to the bare file name
                return self.captures[self.fname_to_index_dict[x]]
        return self.captures[x]

    def _build_img_X_to_index_dict(self):
        '''Fill the path / file-name / image-id lookup tables; duplicates are rejected.'''
        assert self.captures is not None, 'There is no captures'
        for i, cap in enumerate(self.captures):
            assert cap.img_path not in self.img_path_to_index_dict, 'Image already exists'
            self.img_path_to_index_dict[cap.img_path] = i
            fname = os.path.basename(cap.img_path)
            assert fname not in self.fname_to_index_dict, 'Image already exists'
            self.fname_to_index_dict[fname] = i
            if hasattr(cap, 'image_id'):
                self.img_id_to_index_dict[cap.image_id] = i

    def get_captures_given_index_list(self, index_list):
        '''Return the captures at the given indices, in the given order.'''
        return [self.captures[i] for i in index_list]

    def get_covisible_caps(self, cap):
        '''Return the captures of this scene that share at least one 3D point with ``cap``.

        Requires point metadata (scene built with a ``(points, point_meta)`` tuple).
        Image ids that are not part of this scene are ignored.
        '''
        assert cap.img_path in self.img_path_to_index_dict
        assert self.point_meta is not None, 'Scene has no point metadata'
        covis_img_id = set()
        for point_id in cap.point3d_id:
            covis_img_id.update(self.point_meta[point_id].image_ids)
        return [
            self.captures[self.img_id_to_index_dict[img_id]]
            for img_id in covis_img_id
            if img_id in self.img_id_to_index_dict
        ]

    def read_data_to_ram(self, data_list):
        '''Preload the data kinds named in ``data_list`` ('image', 'depth', 'pcd')
        for every capture, reporting the accumulated size while loading.
        '''
        print('warning: you are going to use a lot of RAM.')
        sum_bytes = 0.0
        pbar = tqdm(self.captures, desc='reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0)))
        for cap in pbar:
            if 'image' in data_list:
                sum_bytes += cap.read_image_to_ram()
            if 'depth' in data_list:
                sum_bytes += cap.read_depth_to_ram()
            if 'pcd' in data_list:
                sum_bytes += cap.read_pcd_to_ram()
            pbar.set_description('reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0)))
        print('----- total memory usage for images: {0} MB-----'.format(sum_bytes / (1024.0 * 1024.0)))
87
+
third_party/COTR/COTR/trainers/base_trainer.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import abc
4
+ import time
5
+
6
+ import tqdm
7
+ import torch.nn as nn
8
+ import tensorboardX
9
+
10
+ from COTR.trainers import tensorboard_helper
11
+ from COTR.utils import utils
12
+ from COTR.options import options_utils
13
+
14
+
15
class BaseTrainer(abc.ABC):
    '''base trainer class.
    contains methods for training, validation, and writing output.

    Subclasses implement validate_batch, validate, train_batch and resume
    (and load_pretrained_weights, which __init__ calls when
    ``opt.load_weights`` is set).
    '''

    def __init__(self, opt, model, optimizer, criterion,
                 train_loader, val_loader):
        self.opt = opt
        self.use_cuda = opt.use_cuda
        self.model = model
        self.optim = optimizer
        self.criterion = criterion
        self.train_loader = train_loader
        self.val_loader = val_loader

        # output directory for checkpoints
        self.out = opt.out
        if not os.path.exists(opt.out):
            os.makedirs(opt.out)

        # progress counters and schedule
        self.epoch = 0
        self.iteration = 0
        self.max_iter = opt.max_iter
        self.valid_iter = opt.valid_iter

        self.tb_pusher = tensorboard_helper.TensorboardPusher(opt)
        self.push_opt_to_tb()

        self.need_resume = opt.resume
        if self.need_resume:
            self.resume()
        if self.opt.load_weights:
            self.load_pretrained_weights()

    def push_opt_to_tb(self):
        '''log the textual dump of the options to tensorboard
        '''
        datapack = tensorboard_helper.TensorboardDatapack()
        datapack.set_training(False)
        datapack.set_iteration(self.iteration)
        datapack.add_text({'options': options_utils.opt_to_string(self.opt)})
        self.tb_pusher.push_to_tensorboard(datapack)

    @abc.abstractmethod
    def validate_batch(self, data_pack):
        '''validate one batch of data
        '''

    @abc.abstractmethod
    def validate(self):
        '''validate over the validation data
        '''

    @abc.abstractmethod
    def train_batch(self, data_pack):
        '''train for one batch of data
        '''

    def train_epoch(self):
        '''train for one epoch
        one epoch is iterating the whole training dataset once
        '''
        self.model.train()
        progress = tqdm.tqdm(
            enumerate(self.train_loader),
            initial=self.iteration % len(self.train_loader),
            total=len(self.train_loader),
            desc='Train epoch={0}'.format(self.epoch),
            ncols=80,
            leave=True,
        )
        for _batch_idx, data_pack in progress:
            run_validation = self.iteration % self.valid_iter == 0
            if run_validation:
                time.sleep(2)  # Prevent possible deadlock during epoch transition
                self.validate()
            self.train_batch(data_pack)

            # stop before counting past the iteration budget
            if self.iteration >= self.max_iter:
                break
            self.iteration += 1

    def train(self):
        '''entrance of the whole training process
        '''
        total_epochs = int(math.ceil(1. * self.max_iter / len(self.train_loader)))
        epoch_range = tqdm.trange(self.epoch, total_epochs, desc='Train', ncols=80)
        for epoch in epoch_range:
            self.epoch = epoch
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.train_epoch()
            if self.iteration >= self.max_iter:
                break

    @abc.abstractmethod
    def resume(self):
        '''restore the training state from a previous run
        '''
third_party/COTR/COTR/trainers/cotr_trainer.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import os.path as osp
4
+ import time
5
+
6
+ import tqdm
7
+ import torch
8
+ import numpy as np
9
+ import torchvision.utils as vutils
10
+ from PIL import Image, ImageDraw
11
+
12
+
13
+ from COTR.utils import utils, debug_utils, constants
14
+ from COTR.trainers import base_trainer, tensorboard_helper
15
+ from COTR.projector import pcd_projector
16
+
17
+
18
class COTRTrainer(base_trainer.BaseTrainer):
    '''Trainer for the COTR model: MSE on predicted correspondences, with an
    optional cycle-consistency term, plus tensorboard logging and checkpointing.
    '''

    def __init__(self, opt, model, optimizer, criterion,
                 train_loader, val_loader):
        super().__init__(opt, model, optimizer, criterion,
                         train_loader, val_loader)

    def _compute_loss(self, img, query, target):
        '''Run the model and return ``(pred, loss)`` for one batch.

        Shared by training and validation. The loss is the MSE between the
        predicted and target correspondences; with ``opt.cycle_consis`` a
        cycle term is added over the queries whose round-trip error is below
        10 / constants.MAX_SIZE.
        '''
        pred = self.model(img, query)['pred_corrs']
        loss = torch.nn.functional.mse_loss(pred, target)
        if self.opt.cycle_consis:
            if self.opt.bidirectional:
                cycle = self.model(img, pred)['pred_corrs']
            else:
                # swap the two halves of the side-by-side image and shift the
                # x coordinate by half the width accordingly
                img_reverse = torch.cat([img[..., constants.MAX_SIZE:], img[..., :constants.MAX_SIZE]], dim=-1)
                query_reverse = pred.clone()
                query_reverse[..., 0] = query_reverse[..., 0] - 0.5
                cycle = self.model(img_reverse, query_reverse)['pred_corrs']
                cycle[..., 0] = cycle[..., 0] - 0.5
            mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
            if mask.sum() > 0:
                cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
                loss += cycle_loss
        return pred, loss

    def validate_batch(self, data_pack):
        '''Evaluate one batch without gradients; returns ``(loss_value, pred)``.
        '''
        assert self.model.training is False
        with torch.no_grad():
            img = data_pack['image'].cuda()
            query = data_pack['queries'].cuda()
            target = data_pack['targets'].cuda()
            self.optim.zero_grad()
            pred, loss = self._compute_loss(img, query, target)
            loss_data = loss.data.item()
            if np.isnan(loss_data):
                print('loss is nan while validating')
            return loss_data, pred

    def validate(self):
        '''validate for whole validation dataset

        Pushes the mean loss and the last batch's visualizations to
        tensorboard, saves a checkpoint, and restores the model's previous
        train/eval mode. An empty validation loader skips the logging step.
        '''
        training = self.model.training
        self.model.eval()
        val_loss_list = []
        data_pack = None
        pred = None
        for batch_idx, data_pack in tqdm.tqdm(
                enumerate(self.val_loader), total=len(self.val_loader),
                desc='Valid iteration=%d' % self.iteration, ncols=80,
                leave=False):
            loss_data, pred = self.validate_batch(data_pack)
            val_loss_list.append(loss_data)
        if val_loss_list:
            mean_loss = np.array(val_loss_list).mean()
            validation_data = {'val_loss': mean_loss,
                               'pred': pred,
                               }
            self.push_validation_data(data_pack, validation_data)
        self.save_model()
        if training:
            self.model.train()

    def save_model(self):
        '''Write the rolling checkpoint, plus a numbered snapshot every
        10 * valid_iter iterations.
        '''
        state = {
            'epoch': self.epoch,
            'iteration': self.iteration,
            'optim_state_dict': self.optim.state_dict(),
            'model_state_dict': self.model.state_dict(),
        }
        torch.save(state, osp.join(self.out, 'checkpoint.pth.tar'))
        if self.iteration % (10 * self.valid_iter) == 0:
            torch.save(state, osp.join(self.out, f'{self.iteration}_checkpoint.pth.tar'))

    def draw_corrs(self, imgs, corrs, col=(255, 0, 0)):
        '''Draw correspondence lines on a batch of images.

        ``corrs`` rows are scaled by (2 * MAX_SIZE, MAX_SIZE, 2 * MAX_SIZE,
        MAX_SIZE) before drawing; note the scaling is applied in place.
        Returns the annotated images converted back with
        utils.np_img_to_torch_img, scaled to [0, 1].
        '''
        imgs = utils.torch_img_to_np_img(imgs)
        out = []
        for img, corr in zip(imgs, corrs):
            img = np.interp(img, [img.min(), img.max()], [0, 255]).astype(np.uint8)
            img = Image.fromarray(img)
            draw = ImageDraw.Draw(img)
            corr *= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
            for c in corr:
                draw.line(c, fill=col)
            out.append(np.array(img))
        out = np.array(out) / 255.0
        return utils.np_img_to_torch_img(out)

    def push_validation_data(self, data_pack, validation_data):
        '''Log the validation loss and ground-truth / predicted correspondence
        images for ``data_pack`` to tensorboard.
        '''
        val_loss = validation_data['val_loss']
        pred_corrs = np.concatenate([data_pack['queries'].numpy(), validation_data['pred'].cpu().numpy()], axis=-1)
        pred_corrs = self.draw_corrs(data_pack['image'], pred_corrs)
        gt_corrs = np.concatenate([data_pack['queries'].numpy(), data_pack['targets'].cpu().numpy()], axis=-1)
        gt_corrs = self.draw_corrs(data_pack['image'], gt_corrs, (0, 255, 0))

        gt_img = vutils.make_grid(gt_corrs, normalize=True, scale_each=True)
        pred_img = vutils.make_grid(pred_corrs, normalize=True, scale_each=True)
        tb_datapack = tensorboard_helper.TensorboardDatapack()
        tb_datapack.set_training(False)
        tb_datapack.set_iteration(self.iteration)
        tb_datapack.add_scalar({'loss/val': val_loss})
        tb_datapack.add_image({'image/gt_corrs': gt_img})
        tb_datapack.add_image({'image/pred_corrs': pred_img})
        self.tb_pusher.push_to_tensorboard(tb_datapack)

    def train_batch(self, data_pack):
        '''train for one batch of data
        '''
        img = data_pack['image'].cuda()
        query = data_pack['queries'].cuda()
        target = data_pack['targets'].cuda()

        self.optim.zero_grad()
        pred, loss = self._compute_loss(img, query, target)
        loss_data = loss.data.item()
        if np.isnan(loss_data):
            print('loss is nan during training')
            self.optim.zero_grad()
        else:
            loss.backward()
            self.push_training_data(data_pack, pred, target, loss)
        # NOTE(review): the optimizer step is taken unconditionally, i.e. also
        # after a NaN batch whose gradients were just cleared -- confirm this
        # matches the intended upstream behavior.
        self.optim.step()

    def push_training_data(self, data_pack, pred, target, loss):
        '''Log prediction / target histograms and the training loss to tensorboard.
        '''
        tb_datapack = tensorboard_helper.TensorboardDatapack()
        tb_datapack.set_training(True)
        tb_datapack.set_iteration(self.iteration)
        tb_datapack.add_histogram({'distribution/pred': pred})
        tb_datapack.add_histogram({'distribution/target': target})
        tb_datapack.add_scalar({'loss/train': loss})
        self.tb_pusher.push_to_tensorboard(tb_datapack)

    def resume(self):
        '''resume training:
        resume from the recorded epoch, iteration, and saved weights.
        resume from the model with the same name.

        Reads ``<opt.out>/checkpoint.pth.tar`` for the epoch, iteration and
        optimizer state; the model weights are loaded through
        load_pretrained_weights (i.e. from ``opt.load_weights_path``).
        Raises FileNotFoundError when the checkpoint is missing.
        '''
        if hasattr(self.opt, 'load_weights'):
            assert self.opt.load_weights is None or self.opt.load_weights == False
        # 1. load check point
        checkpoint_path = os.path.join(self.opt.out, 'checkpoint.pth.tar')
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                'model check point cannnot found: {0}'.format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path)
        # 2. load data
        self.epoch = checkpoint['epoch']
        self.iteration = checkpoint['iteration']
        self.load_pretrained_weights()
        self.optim.load_state_dict(checkpoint['optim_state_dict'])

    def load_pretrained_weights(self):
        '''
        load pretrained weights from another model
        '''
        assert os.path.isfile(self.opt.load_weights_path), self.opt.load_weights_path

        saved_weights = torch.load(self.opt.load_weights_path)['model_state_dict']
        utils.safe_load_weights(self.model, saved_weights)
        content_list = []
        content_list += [f'Loaded pretrained weights from {self.opt.load_weights_path}']
        utils.print_notification(content_list)
third_party/COTR/COTR/trainers/tensorboard_helper.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ import tensorboardX
4
+
5
+
6
class TensorboardDatapack():
    '''Container that groups values by kind before they go to tensorboard.

    ``datapack`` maps each kind name ('scalar', 'histogram', 'image',
    'text') to a dictionary of tag -> value. ``training`` and ``iteration``
    only exist after the corresponding setter has been called.
    '''

    def __init__(self):
        self.SCALAR_NAME = 'scalar'
        self.HISTOGRAM_NAME = 'histogram'
        self.IMAGE_NAME = 'image'
        self.TEXT_NAME = 'text'
        kinds = (
            self.SCALAR_NAME,
            self.HISTOGRAM_NAME,
            self.IMAGE_NAME,
            self.TEXT_NAME,
        )
        # one independent dictionary per kind
        self.datapack = {kind: {} for kind in kinds}

    def set_training(self, training):
        '''Record whether the data comes from training.'''
        self.training = training

    def set_iteration(self, iteration):
        '''Record the iteration the data belongs to.'''
        self.iteration = iteration

    def add_scalar(self, scalar_dict):
        '''Merge tag -> value pairs into the scalar section.'''
        self.datapack[self.SCALAR_NAME].update(scalar_dict)

    def add_histogram(self, histogram_dict):
        '''Merge tag -> value pairs into the histogram section.'''
        self.datapack[self.HISTOGRAM_NAME].update(histogram_dict)

    def add_image(self, image_dict):
        '''Merge tag -> value pairs into the image section.'''
        self.datapack[self.IMAGE_NAME].update(image_dict)

    def add_text(self, text_dict):
        '''Merge tag -> value pairs into the text section.'''
        self.datapack[self.TEXT_NAME].update(text_dict)
38
+
39
+
40
class TensorboardHelperBase(abc.ABC):
    '''Abstract base for helpers that write one kind of data to tensorboard.

    Subclasses receive the writer at construction time and implement
    ``add_data``.
    '''

    def __init__(self, tb_writer):
        self.tb_writer = tb_writer

    @abc.abstractmethod
    def add_data(self, tb_datapack):
        '''Write the relevant part of ``tb_datapack`` to ``self.tb_writer``.'''
50
+
51
+
52
class TensorboardScalarHelper(TensorboardHelperBase):
    '''Writes the scalar section of a datapack to the tensorboard writer.'''

    def add_data(self, tb_datapack):
        '''Call ``add_scalar`` on the writer for every scalar entry.'''
        scalars = tb_datapack.datapack[tb_datapack.SCALAR_NAME]
        for tag, value in scalars.items():
            # the iteration is read per entry, so an empty section never
            # touches tb_datapack.iteration
            self.tb_writer.add_scalar(tag, value, global_step=tb_datapack.iteration)
58
+
59
+
60
class TensorboardHistogramHelper(TensorboardHelperBase):
    '''Writes the histogram section of a datapack to the tensorboard writer.'''

    def add_data(self, tb_datapack):
        '''Call ``add_histogram`` on the writer for every histogram entry.'''
        histograms = tb_datapack.datapack[tb_datapack.HISTOGRAM_NAME]
        for tag, values in histograms.items():
            # the iteration is read per entry, so an empty section never
            # touches tb_datapack.iteration
            self.tb_writer.add_histogram(tag, values, global_step=tb_datapack.iteration)
66
+
67
+
68
class TensorboardImageHelper(TensorboardHelperBase):
    '''Writes the image section of a datapack to the tensorboard writer.'''

    def add_data(self, tb_datapack):
        '''Call ``add_image`` on the writer for every image entry.'''
        images = tb_datapack.datapack[tb_datapack.IMAGE_NAME]
        for tag, image in images.items():
            # the iteration is read per entry, so an empty section never
            # touches tb_datapack.iteration
            self.tb_writer.add_image(tag, image, global_step=tb_datapack.iteration)
74
+
75
+
76
class TensorboardTextHelper(TensorboardHelperBase):
    '''Writes the text section of a datapack to the tensorboard writer.'''

    def add_data(self, tb_datapack):
        '''Call ``add_text`` on the writer for every text entry.'''
        texts = tb_datapack.datapack[tb_datapack.TEXT_NAME]
        for tag, text in texts.items():
            # the iteration is read per entry, so an empty section never
            # touches tb_datapack.iteration
            self.tb_writer.add_text(tag, text, global_step=tb_datapack.iteration)
82
+
83
+
84
class TensorboardPusher():
    '''Owns a tensorboardX writer and pushes datapacks through all helpers.'''

    def __init__(self, opt):
        '''Create the writer at ``opt.tb_out`` and one helper per data kind.'''
        self.tb_writer = tensorboardX.SummaryWriter(opt.tb_out)
        # helper order: scalar, histogram, image, text
        helper_types = (
            TensorboardScalarHelper,
            TensorboardHistogramHelper,
            TensorboardImageHelper,
            TensorboardTextHelper,
        )
        self.helper_list = [make(self.tb_writer) for make in helper_types]

    def push_to_tensorboard(self, tb_datapack):
        '''Pass ``tb_datapack`` to every helper in order, then flush the writer.'''
        for tb_helper in self.helper_list:
            tb_helper.add_data(tb_datapack)
        self.tb_writer.flush()
third_party/COTR/COTR/transformations/transform_basics.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from COTR.transformations import transformations
4
+ from COTR.utils import constants
5
+
6
+
7
class Rotation():
    """Rotation stored as a unit quaternion in (w, x, y, z) order.

    The quaternion is validated on every write and read; a non-unit
    quaternion passed to the setter is normalized with a printed warning.
    """

    def __init__(self, quat):
        """
        quaternion format (w, x, y, z)

        ``quat`` must be a float32 numpy array of shape (4,).
        """
        assert quat.dtype == np.float32
        self.quaternion = quat

    def __str__(self):
        return '{0}'.format(self.quaternion)

    @classmethod
    def from_matrix(cls, mat):
        """Build a Rotation from a 3x3 or 4x4 numpy rotation matrix."""
        assert isinstance(mat, np.ndarray)
        if mat.shape == (3, 3):
            # embed the 3x3 rotation into a homogeneous 4x4 matrix
            id_mat = np.eye(4)
            id_mat[0:3, 0:3] = mat
            mat = id_mat
        assert mat.shape == (4, 4)
        quat = transformations.quaternion_from_matrix(mat).astype(constants.DEFAULT_PRECISION)
        return cls(quat)

    @property
    def rotation_matrix(self):
        """Matrix form of the quaternion, cast to ``constants.DEFAULT_PRECISION``."""
        return transformations.quaternion_matrix(self.quaternion).astype(constants.DEFAULT_PRECISION)

    @rotation_matrix.setter
    def rotation_matrix(self, mat):
        assert isinstance(mat, np.ndarray)
        assert mat.shape == (4, 4)
        # Cast like from_matrix does, so the stored quaternion keeps the same
        # precision whichever way it was set (presumably float32, matching
        # the dtype check in __init__ -- confirm against constants).
        quat = transformations.quaternion_from_matrix(mat).astype(constants.DEFAULT_PRECISION)
        self.quaternion = quat

    @property
    def quaternion(self):
        """The stored unit quaternion, re-validated on every access."""
        assert isinstance(self._quaternion, np.ndarray)
        assert self._quaternion.shape == (4,)
        assert np.isclose(np.linalg.norm(self._quaternion), 1.0), 'self._quaternion is not normalized or valid'
        return self._quaternion

    @quaternion.setter
    def quaternion(self, quat):
        assert isinstance(quat, np.ndarray)
        assert quat.shape == (4,)
        norm = np.linalg.norm(quat)
        if not np.isclose(norm, 1.0):
            print(f'WARNING: normalizing the input quaternion to unit quaternion: {norm}')
            quat = quat / norm
        # still fails for inputs that cannot be normalized (e.g. zeros, NaN)
        assert np.isclose(np.linalg.norm(quat), 1.0), f'input quaternion is not normalized or valid: {quat}'
        self._quaternion = quat
57
+
58
+
59
class UnstableRotation():
    """Rotation kept directly as a 4x4 matrix with a zeroed translation column."""

    def __init__(self, mat):
        """Store ``mat`` (3x3 or 4x4 numpy array) as a 4x4 rotation matrix.

        A 3x3 input is embedded into a 4x4 identity. The translation part
        ``[:3, 3]`` of the stored matrix is set to zero. The caller's array
        is never modified.
        """
        assert isinstance(mat, np.ndarray)
        if mat.shape == (3, 3):
            id_mat = np.eye(4)
            id_mat[0:3, 0:3] = mat
            mat = id_mat
        else:
            # work on a private copy: zeroing the translation below must not
            # write into the array owned by the caller
            mat = mat.copy()
        assert mat.shape == (4, 4)
        mat[:3, 3] = 0
        self._rotation_matrix = mat

    def __str__(self):
        return f'rotation_matrix: {self.rotation_matrix}'

    @property
    def rotation_matrix(self):
        """The stored 4x4 matrix."""
        return self._rotation_matrix
77
+
78
+
79
class Translation():
    """Translation stored as a numpy vector of shape (3,)."""

    def __init__(self, vec):
        """``vec`` must be a float32 numpy array of shape (3,)."""
        assert vec.dtype == np.float32
        self.translation_vector = vec

    def __str__(self):
        return '{0}'.format(self.translation_vector)

    @classmethod
    def from_matrix(cls, mat):
        """Build a Translation from the last column of a 4x4 numpy matrix."""
        assert isinstance(mat, np.ndarray)
        assert mat.shape == (4, 4)
        # Cast the same way Rotation.from_matrix does; without it a matrix
        # that is not already float32 fails the dtype check in __init__.
        vec = transformations.translation_from_matrix(mat).astype(constants.DEFAULT_PRECISION)
        return cls(vec)

    @property
    def translation_matrix(self):
        """4x4 translation matrix, cast to ``constants.DEFAULT_PRECISION``."""
        return transformations.translation_matrix(self.translation_vector).astype(constants.DEFAULT_PRECISION)

    @translation_matrix.setter
    def translation_matrix(self, mat):
        assert isinstance(mat, np.ndarray)
        assert mat.shape == (4, 4)
        # keep the stored vector at the same precision as from_matrix produces
        vec = transformations.translation_from_matrix(mat).astype(constants.DEFAULT_PRECISION)
        self.translation_vector = vec

    @property
    def translation_vector(self):
        """The stored translation vector."""
        return self._translation_vector

    @translation_vector.setter
    def translation_vector(self, vec):
        assert isinstance(vec, np.ndarray)
        assert vec.shape == (3,)
        self._translation_vector = vec
third_party/COTR/COTR/transformations/transformations.py ADDED
@@ -0,0 +1,1951 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # transformations.py
3
+
4
+ # Copyright (c) 2006-2019, Christoph Gohlke
5
+ # Copyright (c) 2006-2019, The Regents of the University of California
6
+ # Produced at the Laboratory for Fluorescence Dynamics
7
+ # All rights reserved.
8
+ #
9
+ # Redistribution and use in source and binary forms, with or without
10
+ # modification, are permitted provided that the following conditions are met:
11
+ #
12
+ # * Redistributions of source code must retain the above copyright notice,
13
+ # this list of conditions and the following disclaimer.
14
+ #
15
+ # * Redistributions in binary form must reproduce the above copyright notice,
16
+ # this list of conditions and the following disclaimer in the documentation
17
+ # and/or other materials provided with the distribution.
18
+ #
19
+ # * Neither the name of the copyright holder nor the names of its
20
+ # contributors may be used to endorse or promote products derived from
21
+ # this software without specific prior written permission.
22
+ #
23
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
24
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
26
+ # ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
27
+ # LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
28
+ # CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
29
+ # SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
30
+ # INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
31
+ # CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
32
+ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
33
+ # POSSIBILITY OF SUCH DAMAGE.
34
+
35
+ """Homogeneous Transformation Matrices and Quaternions.
36
+
37
+ Transformations is a Python library for calculating 4x4 matrices for
38
+ translating, rotating, reflecting, scaling, shearing, projecting,
39
+ orthogonalizing, and superimposing arrays of 3D homogeneous coordinates
40
+ as well as for converting between rotation matrices, Euler angles,
41
+ and quaternions. Also includes an Arcball control object and
42
+ functions to decompose transformation matrices.
43
+
44
+ :Author:
45
+ `Christoph Gohlke <https://www.lfd.uci.edu/~gohlke/>`_
46
+
47
+ :Organization:
48
+ Laboratory for Fluorescence Dynamics. University of California, Irvine
49
+
50
+ :License: 3-clause BSD
51
+
52
+ :Version: 2019.2.20
53
+
54
+ Requirements
55
+ ------------
56
+ * `CPython 2.7 or 3.5+ <https://www.python.org>`_
57
+ * `Numpy 1.14 <https://www.numpy.org>`_
58
+ * A Python distutils compatible C compiler (build)
59
+
60
+ Revisions
61
+ ---------
62
+ 2019.1.1
63
+ Update copyright year.
64
+
65
+ Notes
66
+ -----
67
+ Transformations.py is no longer actively developed and has a few known issues
68
+ and numerical instabilities. The module is mostly superseded by other modules
69
+ for 3D transformations and quaternions:
70
+
71
+ * `Scipy.spatial.transform <https://github.com/scipy/scipy/tree/master/
72
+ scipy/spatial/transform>`_
73
+ * `Transforms3d <https://github.com/matthew-brett/transforms3d>`_
74
+ (includes most code of this module)
75
+ * `Numpy-quaternion <https://github.com/moble/quaternion>`_
76
+ * `Blender.mathutils <https://docs.blender.org/api/master/mathutils.html>`_
77
+
78
+ The API is not stable yet and is expected to change between revisions.
79
+
80
+ Python 2.7 and 3.4 are deprecated.
81
+
82
+ This Python code is not optimized for speed. Refer to the transformations.c
83
+ module for a faster implementation of some functions.
84
+
85
+ Documentation in HTML format can be generated with epydoc.
86
+
87
+ Matrices (M) can be inverted using numpy.linalg.inv(M), be concatenated using
88
+ numpy.dot(M0, M1), or transform homogeneous coordinate arrays (v) using
89
+ numpy.dot(M, v) for shape (4, \*) column vectors, respectively
90
+ numpy.dot(v, M.T) for shape (\*, 4) row vectors ("array of points").
91
+
92
+ This module follows the "column vectors on the right" and "row major storage"
93
+ (C contiguous) conventions. The translation components are in the right column
94
+ of the transformation matrix, i.e. M[:3, 3].
95
+ The transpose of the transformation matrices may have to be used to interface
96
+ with other graphics systems, e.g. OpenGL's glMultMatrixd(). See also [16].
97
+
98
+ Calculations are carried out with numpy.float64 precision.
99
+
100
+ Vector, point, quaternion, and matrix function arguments are expected to be
101
+ "array like", i.e. tuple, list, or numpy arrays.
102
+
103
+ Return types are numpy arrays unless specified otherwise.
104
+
105
+ Angles are in radians unless specified otherwise.
106
+
107
+ Quaternions w+ix+jy+kz are represented as [w, x, y, z].
108
+
109
+ A triple of Euler angles can be applied/interpreted in 24 ways, which can
110
+ be specified using a 4 character string or encoded 4-tuple:
111
+
112
+ *Axes 4-string*: e.g. 'sxyz' or 'ryxy'
113
+
114
+ - first character : rotations are applied to 's'tatic or 'r'otating frame
115
+ - remaining characters : successive rotation axis 'x', 'y', or 'z'
116
+
117
+ *Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1)
118
+
119
+ - inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix.
120
+ - parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed
121
+ by 'z', or 'z' is followed by 'x'. Otherwise odd (1).
122
+ - repetition : first and last axis are same (1) or different (0).
123
+ - frame : rotations are applied to static (0) or rotating (1) frame.
124
+
125
+ References
126
+ ----------
127
+ (1) Matrices and transformations. Ronald Goldman.
128
+ In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990.
129
+ (2) More matrices and transformations: shear and pseudo-perspective.
130
+ Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
131
+ (3) Decomposing a matrix into simple transformations. Spencer Thomas.
132
+ In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
133
+ (4) Recovering the data from the transformation matrix. Ronald Goldman.
134
+ In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991.
135
+ (5) Euler angle conversion. Ken Shoemake.
136
+ In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994.
137
+ (6) Arcball rotation control. Ken Shoemake.
138
+ In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994.
139
+ (7) Representing attitude: Euler angles, unit quaternions, and rotation
140
+ vectors. James Diebel. 2006.
141
+ (8) A discussion of the solution for the best rotation to relate two sets
142
+ of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828.
143
+ (9) Closed-form solution of absolute orientation using unit quaternions.
144
+ BKP Horn. J Opt Soc Am A. 1987. 4(4):629-642.
145
+ (10) Quaternions. Ken Shoemake.
146
+ http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf
147
+ (11) From quaternion to matrix and back. JMP van Waveren. 2005.
148
+ http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm
149
+ (12) Uniform random rotations. Ken Shoemake.
150
+ In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992.
151
+ (13) Quaternion in molecular modeling. CFF Karney.
152
+ J Mol Graph Mod, 25(5):595-604
153
+ (14) New method for extracting the quaternion from a rotation matrix.
154
+ Itzhack Y Bar-Itzhack, J Guid Contr Dynam. 2000. 23(6): 1085-1087.
155
+ (15) Multiple View Geometry in Computer Vision. Hartley and Zisserman.
156
+ Cambridge University Press; 2nd Ed. 2004. Chapter 4, Algorithm 4.7, p 130.
157
+ (16) Column Vectors vs. Row Vectors.
158
+ http://steve.hollasch.net/cgindex/math/matrix/column-vec.html
159
+
160
+ Examples
161
+ --------
162
+ >>> alpha, beta, gamma = 0.123, -1.234, 2.345
163
+ >>> origin, xaxis, yaxis, zaxis = [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]
164
+ >>> I = identity_matrix()
165
+ >>> Rx = rotation_matrix(alpha, xaxis)
166
+ >>> Ry = rotation_matrix(beta, yaxis)
167
+ >>> Rz = rotation_matrix(gamma, zaxis)
168
+ >>> R = concatenate_matrices(Rx, Ry, Rz)
169
+ >>> euler = euler_from_matrix(R, 'rxyz')
170
+ >>> numpy.allclose([alpha, beta, gamma], euler)
171
+ True
172
+ >>> Re = euler_matrix(alpha, beta, gamma, 'rxyz')
173
+ >>> is_same_transform(R, Re)
174
+ True
175
+ >>> al, be, ga = euler_from_matrix(Re, 'rxyz')
176
+ >>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz'))
177
+ True
178
+ >>> qx = quaternion_about_axis(alpha, xaxis)
179
+ >>> qy = quaternion_about_axis(beta, yaxis)
180
+ >>> qz = quaternion_about_axis(gamma, zaxis)
181
+ >>> q = quaternion_multiply(qx, qy)
182
+ >>> q = quaternion_multiply(q, qz)
183
+ >>> Rq = quaternion_matrix(q)
184
+ >>> is_same_transform(R, Rq)
185
+ True
186
+ >>> S = scale_matrix(1.23, origin)
187
+ >>> T = translation_matrix([1, 2, 3])
188
+ >>> Z = shear_matrix(beta, xaxis, origin, zaxis)
189
+ >>> R = random_rotation_matrix(numpy.random.rand(3))
190
+ >>> M = concatenate_matrices(T, R, Z, S)
191
+ >>> scale, shear, angles, trans, persp = decompose_matrix(M)
192
+ >>> numpy.allclose(scale, 1.23)
193
+ True
194
+ >>> numpy.allclose(trans, [1, 2, 3])
195
+ True
196
+ >>> numpy.allclose(shear, [0, math.tan(beta), 0])
197
+ True
198
+ >>> is_same_transform(R, euler_matrix(axes='sxyz', *angles))
199
+ True
200
+ >>> M1 = compose_matrix(scale, shear, angles, trans, persp)
201
+ >>> is_same_transform(M, M1)
202
+ True
203
+ >>> v0, v1 = random_vector(3), random_vector(3)
204
+ >>> M = rotation_matrix(angle_between_vectors(v0, v1), vector_product(v0, v1))
205
+ >>> v2 = numpy.dot(v0, M[:3,:3].T)
206
+ >>> numpy.allclose(unit_vector(v1), unit_vector(v2))
207
+ True
208
+
209
+ """
210
+
211
+ from __future__ import division, print_function
212
+
213
+ __version__ = '2019.2.20'
214
+ __docformat__ = 'restructuredtext en'
215
+
216
+ import math
217
+
218
+ import numpy
219
+
220
+
221
def identity_matrix():
    """Return a new 4x4 identity (unit) matrix.

    >>> I = identity_matrix()
    >>> numpy.allclose(I, numpy.dot(I, I))
    True
    >>> numpy.sum(I), numpy.trace(I)
    (4.0, 4.0)
    >>> numpy.allclose(I, numpy.identity(4))
    True

    """
    unit = numpy.eye(4)
    return unit
234
+
235
+
236
def translation_matrix(direction):
    """Return the 4x4 matrix that translates by the direction vector.

    Only the first three components of ``direction`` are used; they end up
    in the right column ``M[:3, 3]``.

    >>> v = numpy.random.random(3) - 0.5
    >>> numpy.allclose(v, translation_matrix(v)[:3, 3])
    True

    """
    offset = direction[:3]
    matrix = numpy.identity(4)
    matrix[:3, 3] = offset
    return matrix
247
+
248
+
249
def translation_from_matrix(matrix):
    """Return translation vector from translation matrix.

    ``matrix`` may be any array-like with at least 3 rows and 4 columns;
    the returned vector is an independent copy of ``matrix[:3, 3]``.

    >>> v0 = numpy.random.random(3) - 0.5
    >>> v1 = translation_from_matrix(translation_matrix(v0))
    >>> numpy.allclose(v0, v1)
    True

    """
    # numpy.asarray converts only when needed. numpy.array(..., copy=False)
    # raises ValueError on NumPy 2 whenever a copy is unavoidable (e.g. for
    # list input).
    return numpy.asarray(matrix)[:3, 3].copy()
259
+
260
+
261
def reflection_matrix(point, normal):
    """Return the 4x4 matrix mirroring at the plane given by point and normal.

    Only the first three components of ``point`` and ``normal`` are used;
    the normal is normalized with ``unit_vector`` first.

    >>> v0 = numpy.random.random(4) - 0.5
    >>> v0[3] = 1.
    >>> v1 = numpy.random.random(3) - 0.5
    >>> R = reflection_matrix(v0, v1)
    >>> numpy.allclose(2, numpy.trace(R))
    True
    >>> numpy.allclose(v0, numpy.dot(R, v0))
    True
    >>> v2 = v0.copy()
    >>> v2[:3] += v1
    >>> v3 = v0.copy()
    >>> v3[:3] -= v1
    >>> numpy.allclose(v2, numpy.dot(R, v3))
    True

    """
    unit_normal = unit_vector(normal[:3])
    mirror = numpy.identity(4)
    # linear part: I - 2 n n^T
    mirror[:3, :3] -= 2.0 * numpy.outer(unit_normal, unit_normal)
    # translation part: 2 (p . n) n
    plane_offset = 2.0 * numpy.dot(point[:3], unit_normal)
    mirror[:3, 3] = plane_offset * unit_normal
    return mirror
285
+
286
+
287
def reflection_from_matrix(matrix):
    """Return mirror plane point and normal vector from reflection matrix.

    ``matrix`` is a 4x4 array-like. The returned ``point`` is a homogeneous
    4-vector scaled so that its last component is 1; ``normal`` is a
    3-vector.

    Raises ValueError if the upper-left 3x3 block has no eigenvalue -1 or
    the full matrix has no eigenvalue 1.

    >>> v0 = numpy.random.random(3) - 0.5
    >>> v1 = numpy.random.random(3) - 0.5
    >>> M0 = reflection_matrix(v0, v1)
    >>> point, normal = reflection_from_matrix(M0)
    >>> M1 = reflection_matrix(point, normal)
    >>> is_same_transform(M0, M1)
    True

    """
    # asarray instead of array(..., copy=False): the latter raises ValueError
    # on NumPy 2 when the input has to be converted (lists, other dtypes).
    M = numpy.asarray(matrix, dtype=numpy.float64)
    # normal: unit eigenvector corresponding to eigenvalue -1
    w, V = numpy.linalg.eig(M[:3, :3])
    i = numpy.where(abs(numpy.real(w) + 1.0) < 1e-8)[0]
    if i.size == 0:
        raise ValueError('no unit eigenvector corresponding to eigenvalue -1')
    normal = numpy.real(V[:, i[0]]).squeeze()
    # point: any unit eigenvector corresponding to eigenvalue 1
    w, V = numpy.linalg.eig(M)
    i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
    if i.size == 0:
        raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
    point = numpy.real(V[:, i[-1]]).squeeze()
    # normalize the homogeneous coordinate to 1
    point /= point[3]
    return point, normal
314
+
315
+
316
def rotation_matrix(angle, direction, point=None):
    """Return matrix to rotate about axis defined by point and direction.

    ``angle`` is passed to ``math.sin``/``math.cos`` (radians). Only the
    first three components of ``direction`` and ``point`` are used. When
    ``point`` is None the axis passes through the origin.

    >>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0])
    >>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1])
    True
    >>> angle = (random.random() - 0.5) * (2*math.pi)
    >>> direc = numpy.random.random(3) - 0.5
    >>> point = numpy.random.random(3) - 0.5
    >>> R0 = rotation_matrix(angle, direc, point)
    >>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
    >>> is_same_transform(R0, R1)
    True
    >>> R0 = rotation_matrix(angle, direc, point)
    >>> R1 = rotation_matrix(-angle, -direc, point)
    >>> is_same_transform(R0, R1)
    True
    >>> I = numpy.identity(4, numpy.float64)
    >>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
    True
    >>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2,
    ...                                               direc, point)))
    True

    """
    sina = math.sin(angle)
    cosa = math.cos(angle)
    direction = unit_vector(direction[:3])
    # rotation matrix around unit vector
    R = numpy.diag([cosa, cosa, cosa])
    R += numpy.outer(direction, direction) * (1.0 - cosa)
    direction *= sina
    # skew-symmetric cross-product matrix of (sin(angle) * axis)
    R += numpy.array([[0.0, -direction[2], direction[1]],
                      [direction[2], 0.0, -direction[0]],
                      [-direction[1], direction[0], 0.0]])
    M = numpy.identity(4)
    M[:3, :3] = R
    if point is not None:
        # rotation not around origin
        # asarray instead of array(..., copy=False): the latter raises
        # ValueError on NumPy 2 when the input has to be converted.
        point = numpy.asarray(point[:3], dtype=numpy.float64)
        M[:3, 3] = point - numpy.dot(R, point)
    return M
358
+
359
+
360
def rotation_from_matrix(matrix):
    """Return rotation angle and axis from rotation matrix.

    ``matrix`` is a 4x4 array-like. Returns ``(angle, direction, point)``:
    the angle from ``math.atan2``, the axis direction as a 3-vector and a
    homogeneous point on the axis scaled so that its last component is 1.

    Raises ValueError if no eigenvalue 1 is found.

    >>> angle = (random.random() - 0.5) * (2*math.pi)
    >>> direc = numpy.random.random(3) - 0.5
    >>> point = numpy.random.random(3) - 0.5
    >>> R0 = rotation_matrix(angle, direc, point)
    >>> angle, direc, point = rotation_from_matrix(R0)
    >>> R1 = rotation_matrix(angle, direc, point)
    >>> is_same_transform(R0, R1)
    True

    """
    # asarray instead of array(..., copy=False): the latter raises ValueError
    # on NumPy 2 when the input has to be converted (lists, other dtypes).
    R = numpy.asarray(matrix, dtype=numpy.float64)
    R33 = R[:3, :3]
    # direction: unit eigenvector of R33 corresponding to eigenvalue of 1
    w, W = numpy.linalg.eig(R33.T)
    i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
    if i.size == 0:
        raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
    direction = numpy.real(W[:, i[-1]]).squeeze()
    # point: unit eigenvector of the full 4x4 matrix R corresponding to
    # eigenvalue of 1
    w, Q = numpy.linalg.eig(R)
    i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
    if i.size == 0:
        raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
    point = numpy.real(Q[:, i[-1]]).squeeze()
    point /= point[3]
    # rotation angle depending on direction; divide by a direction component
    # that is safely non-zero
    cosa = (numpy.trace(R33) - 1.0) / 2.0
    if abs(direction[2]) > 1e-8:
        sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
    elif abs(direction[1]) > 1e-8:
        sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
    else:
        sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
    angle = math.atan2(sina, cosa)
    return angle, direction, point
398
+
399
+
400
def scale_matrix(factor, origin=None, direction=None):
    """Return the 4x4 matrix scaling by factor around origin in direction.

    Without ``direction`` the scaling is uniform; with it, only the
    component along ``direction`` is scaled. ``origin`` is the fixed point
    of the scaling (the coordinate origin when None).

    Use factor -1 for point symmetry.

    >>> v = (numpy.random.rand(4, 5) - 0.5) * 20
    >>> v[3] = 1
    >>> S = scale_matrix(-1.234)
    >>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3])
    True
    >>> factor = random.random() * 10 - 5
    >>> origin = numpy.random.random(3) - 0.5
    >>> direct = numpy.random.random(3) - 0.5
    >>> S = scale_matrix(factor, origin)
    >>> S = scale_matrix(factor, origin, direct)

    """
    if direction is None:
        # uniform scaling
        uniform = numpy.diag([factor, factor, factor, 1.0])
        if origin is not None:
            uniform[:3, 3] = origin[:3]
            uniform[:3, 3] *= 1.0 - factor
        return uniform

    # nonuniform scaling
    axis = unit_vector(direction[:3])
    shrink = 1.0 - factor
    directed = numpy.identity(4)
    directed[:3, :3] -= shrink * numpy.outer(axis, axis)
    if origin is not None:
        directed[:3, 3] = (shrink * numpy.dot(origin[:3], axis)) * axis
    return directed
432
+
433
+
434
def scale_from_matrix(matrix):
    """Return scaling factor, origin and direction from scaling matrix.

    ``matrix`` is a 4x4 array-like. ``direction`` is None when the scaling
    is uniform. ``origin`` is a homogeneous 4-vector scaled so that its
    last component is 1.

    Raises ValueError if the matrix has no eigenvalue 1.

    >>> factor = random.random() * 10 - 5
    >>> origin = numpy.random.random(3) - 0.5
    >>> direct = numpy.random.random(3) - 0.5
    >>> S0 = scale_matrix(factor, origin)
    >>> factor, origin, direction = scale_from_matrix(S0)
    >>> S1 = scale_matrix(factor, origin, direction)
    >>> is_same_transform(S0, S1)
    True
    >>> S0 = scale_matrix(factor, origin, direct)
    >>> factor, origin, direction = scale_from_matrix(S0)
    >>> S1 = scale_matrix(factor, origin, direction)
    >>> is_same_transform(S0, S1)
    True

    """
    # asarray instead of array(..., copy=False): the latter raises ValueError
    # on NumPy 2 when the input has to be converted (lists, other dtypes).
    M = numpy.asarray(matrix, dtype=numpy.float64)
    M33 = M[:3, :3]
    # for directional scaling the 3x3 eigenvalues are (factor, 1, 1)
    factor = numpy.trace(M33) - 2.0
    # direction: unit eigenvector corresponding to eigenvalue factor
    w, V = numpy.linalg.eig(M33)
    matches = numpy.where(abs(numpy.real(w) - factor) < 1e-8)[0]
    if matches.size:
        direction = numpy.real(V[:, matches[0]]).squeeze()
        direction /= vector_norm(direction)
    else:
        # uniform scaling: all three eigenvalues equal the factor
        factor = (factor + 2.0) / 3.0
        direction = None
    # origin: any eigenvector corresponding to eigenvalue 1
    w, V = numpy.linalg.eig(M)
    i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
    if i.size == 0:
        raise ValueError('no eigenvector corresponding to eigenvalue 1')
    origin = numpy.real(V[:, i[-1]]).squeeze()
    origin /= origin[3]
    return factor, origin, direction
473
+
474
+
475
def projection_matrix(point, normal, direction=None,
                      perspective=None, pseudo=False):
    """Return matrix to project onto plane defined by point and normal.

    Using either perspective point, projection direction, or none of both.
    ``perspective`` takes precedence over ``direction`` when both are given.
    Only the first three components of the vector arguments are used.

    If pseudo is True, perspective projections will preserve relative depth
    such that Perspective = dot(Orthogonal, PseudoPerspective).

    >>> P = projection_matrix([0, 0, 0], [1, 0, 0])
    >>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:])
    True
    >>> point = numpy.random.random(3) - 0.5
    >>> normal = numpy.random.random(3) - 0.5
    >>> direct = numpy.random.random(3) - 0.5
    >>> persp = numpy.random.random(3) - 0.5
    >>> P0 = projection_matrix(point, normal)
    >>> P1 = projection_matrix(point, normal, direction=direct)
    >>> P2 = projection_matrix(point, normal, perspective=persp)
    >>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True)
    >>> is_same_transform(P2, numpy.dot(P0, P3))
    True
    >>> P = projection_matrix([3, 0, 0], [1, 1, 0], [1, 0, 0])
    >>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20
    >>> v0[3] = 1
    >>> v1 = numpy.dot(P, v0)
    >>> numpy.allclose(v1[1], v0[1])
    True
    >>> numpy.allclose(v1[0], 3-v1[1])
    True

    """
    M = numpy.identity(4)
    # asarray instead of array(..., copy=False) throughout: the latter raises
    # ValueError on NumPy 2 when the input has to be converted.
    point = numpy.asarray(point[:3], dtype=numpy.float64)
    normal = unit_vector(normal[:3])
    if perspective is not None:
        # perspective projection
        perspective = numpy.asarray(perspective[:3], dtype=numpy.float64)
        M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
        M[:3, :3] -= numpy.outer(perspective, normal)
        if pseudo:
            # preserve relative depth
            M[:3, :3] -= numpy.outer(normal, normal)
            M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
        else:
            M[:3, 3] = numpy.dot(point, normal) * perspective
        M[3, :3] = -normal
        M[3, 3] = numpy.dot(perspective, normal)
    elif direction is not None:
        # parallel projection
        direction = numpy.asarray(direction[:3], dtype=numpy.float64)
        scale = numpy.dot(direction, normal)
        M[:3, :3] -= numpy.outer(direction, normal) / scale
        M[:3, 3] = direction * (numpy.dot(point, normal) / scale)
    else:
        # orthogonal projection
        M[:3, :3] -= numpy.outer(normal, normal)
        M[:3, 3] = numpy.dot(point, normal) * normal
    return M
535
+
536
+
537
def projection_from_matrix(matrix, pseudo=False):
    """Return projection plane and perspective point from projection matrix.

    Return values are same as arguments for projection_matrix function:
    point, normal, direction, perspective, and pseudo.

    matrix : array_like, 4x4 projection matrix
    pseudo : bool, passed through for perspective projections

    Raise ValueError if the required eigenvectors cannot be found.

    >>> point = numpy.random.random(3) - 0.5
    >>> normal = numpy.random.random(3) - 0.5
    >>> direct = numpy.random.random(3) - 0.5
    >>> persp = numpy.random.random(3) - 0.5
    >>> P0 = projection_matrix(point, normal)
    >>> result = projection_from_matrix(P0)
    >>> P1 = projection_matrix(*result)
    >>> is_same_transform(P0, P1)
    True
    >>> P0 = projection_matrix(point, normal, direct)
    >>> result = projection_from_matrix(P0)
    >>> P1 = projection_matrix(*result)
    >>> is_same_transform(P0, P1)
    True
    >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False)
    >>> result = projection_from_matrix(P0, pseudo=False)
    >>> P1 = projection_matrix(*result)
    >>> is_same_transform(P0, P1)
    True
    >>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True)
    >>> result = projection_from_matrix(P0, pseudo=True)
    >>> P1 = projection_matrix(*result)
    >>> is_same_transform(P0, P1)
    True

    """
    # numpy.asarray copies only when required; numpy.array(..., copy=False)
    # raises ValueError for list input under NumPy >= 2.0.
    M = numpy.asarray(matrix, dtype=numpy.float64)
    M33 = M[:3, :3]
    w, V = numpy.linalg.eig(M)
    i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
    if not pseudo and len(i):
        # point: any eigenvector corresponding to eigenvalue 1
        point = numpy.real(V[:, i[-1]]).squeeze()
        point /= point[3]
        # direction: unit eigenvector corresponding to eigenvalue 0
        w, V = numpy.linalg.eig(M33)
        i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
        if not len(i):
            raise ValueError('no eigenvector corresponding to eigenvalue 0')
        direction = numpy.real(V[:, i[0]]).squeeze()
        direction /= vector_norm(direction)
        # normal: unit eigenvector of M33.T corresponding to eigenvalue 0
        w, V = numpy.linalg.eig(M33.T)
        i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
        if len(i):
            # parallel projection
            normal = numpy.real(V[:, i[0]]).squeeze()
            normal /= vector_norm(normal)
            return point, normal, direction, None, False
        else:
            # orthogonal projection, where normal equals direction vector
            return point, direction, None, None, False
    else:
        # perspective projection
        i = numpy.where(abs(numpy.real(w)) > 1e-8)[0]
        if not len(i):
            raise ValueError(
                'no eigenvector not corresponding to eigenvalue 0')
        point = numpy.real(V[:, i[-1]]).squeeze()
        point /= point[3]
        # unary minus and the division below allocate new arrays, so the
        # caller's matrix is never modified
        normal = - M[3, :3]
        perspective = M[:3, 3] / numpy.dot(point[:3], normal)
        if pseudo:
            perspective -= normal
        return point, normal, None, perspective, pseudo
608
+
609
+
610
def clip_matrix(left, right, bottom, top, near, far, perspective=False):
    """Return matrix mapping a frustum to normalized device coordinates.

    The frustum is axis-aligned: (left, right) bound x, (bottom, top)
    bound y and (near, far) bound z. Coordinates inside the frustum map
    into the range [-1, 1].

    With perspective=True the frustum is a truncated pyramid whose
    perspective point is the origin and whose direction is the z axis;
    otherwise it is an orthographic canonical view volume (a box).

    Homogeneous coordinates transformed by the perspective clip matrix
    need to be dehomogenized (divided by the w coordinate).

    Raise ValueError if any lower bound is not strictly below its upper
    bound, or if perspective is requested with near <= _EPS.

    >>> frustum = numpy.random.rand(6)
    >>> frustum[1] += frustum[0]
    >>> frustum[3] += frustum[2]
    >>> frustum[5] += frustum[4]
    >>> M = clip_matrix(perspective=False, *frustum)
    >>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
    >>> numpy.allclose(v, [-1, -1, -1, 1])
    True
    >>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[5], 1])
    >>> numpy.allclose(v, [1, 1, 1, 1])
    True
    >>> M = clip_matrix(perspective=True, *frustum)
    >>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
    >>> numpy.allclose(v / v[3], [-1, -1, -1, 1])
    True
    >>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[4], 1])
    >>> numpy.allclose(v / v[3], [1, 1, -1, 1])
    True

    """
    if left >= right or bottom >= top or near >= far:
        raise ValueError('invalid frustum')

    if not perspective:
        # orthographic box: scale each axis and shift its centre to zero
        x_row = [2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)]
        y_row = [0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)]
        z_row = [0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)]
        w_row = [0.0, 0.0, 0.0, 1.0]
        return numpy.array([x_row, y_row, z_row, w_row])

    if near <= _EPS:
        raise ValueError('invalid frustum: near <= 0')
    twice_near = 2.0 * near
    x_row = [twice_near/(left-right), 0.0, (right+left)/(right-left), 0.0]
    y_row = [0.0, twice_near/(bottom-top), (top+bottom)/(top-bottom), 0.0]
    z_row = [0.0, 0.0, (far+near)/(near-far), twice_near*far/(far-near)]
    w_row = [0.0, 0.0, -1.0, 0.0]
    return numpy.array([x_row, y_row, z_row, w_row])
660
+
661
+
662
def shear_matrix(angle, direction, point, normal):
    """Return matrix to shear by angle along direction vector on shear plane.

    The shear plane is given by a point and a normal vector. The direction
    vector must be orthogonal to the plane's normal vector, otherwise
    ValueError is raised.

    A point P is transformed by the shear matrix into P" such that
    the vector P-P" is parallel to the direction vector and its extent is
    given by the angle of P-P'-P", where P' is the orthogonal projection
    of P onto the shear plane.

    >>> angle = (random.random() - 0.5) * 4*math.pi
    >>> direct = numpy.random.random(3) - 0.5
    >>> point = numpy.random.random(3) - 0.5
    >>> normal = numpy.cross(direct, numpy.random.random(3))
    >>> S = shear_matrix(angle, direct, point, normal)
    >>> numpy.allclose(1, numpy.linalg.det(S))
    True

    """
    plane_normal = unit_vector(normal[:3])
    shear_dir = unit_vector(direction[:3])
    if abs(numpy.dot(plane_normal, shear_dir)) > 1e-6:
        raise ValueError('direction and normal vectors are not orthogonal')
    # the shear factor is the tangent of the requested angle
    factor = math.tan(angle)
    result = numpy.identity(4)
    result[:3, :3] += factor * numpy.outer(shear_dir, plane_normal)
    result[:3, 3] = -factor * numpy.dot(point[:3], plane_normal) * shear_dir
    return result
691
+
692
+
693
def shear_from_matrix(matrix):
    """Return shear angle, direction and plane from shear matrix.

    matrix : array_like, 4x4 shear matrix

    Returns the tuple (angle, direction, point, normal), matching the
    arguments of shear_matrix.

    Raise ValueError if fewer than two eigenvalues of the upper-left 3x3
    block are close to 1, or if the full matrix has no eigenvalue 1.

    >>> angle = (random.random() - 0.5) * 4*math.pi
    >>> direct = numpy.random.random(3) - 0.5
    >>> point = numpy.random.random(3) - 0.5
    >>> normal = numpy.cross(direct, numpy.random.random(3))
    >>> S0 = shear_matrix(angle, direct, point, normal)
    >>> angle, direct, point, normal = shear_from_matrix(S0)
    >>> S1 = shear_matrix(angle, direct, point, normal)
    >>> is_same_transform(S0, S1)
    True

    """
    # numpy.asarray copies only when required; numpy.array(..., copy=False)
    # raises ValueError for list input under NumPy >= 2.0.
    M = numpy.asarray(matrix, dtype=numpy.float64)
    M33 = M[:3, :3]
    # normal: cross independent eigenvectors corresponding to the eigenvalue 1
    w, V = numpy.linalg.eig(M33)
    i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-4)[0]
    if len(i) < 2:
        raise ValueError('no two linear independent eigenvectors found %s' % w)
    V = numpy.real(V[:, i]).squeeze().T
    lenorm = -1.0
    # keep the cross product with the largest norm, i.e. the pair of
    # eigenvectors that is furthest from being parallel
    for i0, i1 in ((0, 1), (0, 2), (1, 2)):
        n = numpy.cross(V[i0], V[i1])
        w = vector_norm(n)
        if w > lenorm:
            lenorm = w
            normal = n
    normal /= lenorm
    # direction and angle
    direction = numpy.dot(M33 - numpy.identity(3), normal)
    angle = vector_norm(direction)
    direction /= angle
    angle = math.atan(angle)
    # point: eigenvector corresponding to eigenvalue 1
    w, V = numpy.linalg.eig(M)
    i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
    if not len(i):
        raise ValueError('no eigenvector corresponding to eigenvalue 1')
    point = numpy.real(V[:, i[-1]]).squeeze()
    point /= point[3]
    return angle, direction, point, normal
736
+
737
+
738
def decompose_matrix(matrix):
    """Return sequence of transformations from transformation matrix.

    matrix : array_like
        Non-degenerative homogeneous transformation matrix

    Return tuple of:
        scale : vector of 3 scaling factors
        shear : list of shear factors for x-y, x-z, y-z axes
        angles : list of Euler angles about static x, y, z axes
        translate : translation vector along x, y, z axes
        perspective : perspective partition of matrix

    Raise ValueError if matrix is of wrong type or degenerative.

    >>> T0 = translation_matrix([1, 2, 3])
    >>> scale, shear, angles, trans, persp = decompose_matrix(T0)
    >>> T1 = translation_matrix(trans)
    >>> numpy.allclose(T0, T1)
    True
    >>> S = scale_matrix(0.123)
    >>> scale, shear, angles, trans, persp = decompose_matrix(S)
    >>> scale[0]
    0.123
    >>> R0 = euler_matrix(1, 2, 3)
    >>> scale, shear, angles, trans, persp = decompose_matrix(R0)
    >>> R1 = euler_matrix(*angles)
    >>> numpy.allclose(R0, R1)
    True

    """
    # work on a private, transposed copy of the input
    work = numpy.array(matrix, dtype=numpy.float64, copy=True).T
    if abs(work[3, 3]) < _EPS:
        raise ValueError('M[3, 3] is zero')
    work /= work[3, 3]

    # same matrix with the last column reset, used for the singularity
    # test and for isolating the perspective partition
    affine = work.copy()
    affine[:, 3] = 0.0, 0.0, 0.0, 1.0
    if not numpy.linalg.det(affine):
        raise ValueError('matrix is singular')

    if (abs(work[:3, 3]) > _EPS).any():
        perspective = numpy.dot(work[:, 3], numpy.linalg.inv(affine.T))
        work[:, 3] = 0.0, 0.0, 0.0, 1.0
    else:
        perspective = numpy.array([0.0, 0.0, 0.0, 1.0])

    translate = work[3, :3].copy()
    work[3, :3] = 0.0

    # orthonormalise the three rows one after another, recording the
    # lengths as scale factors and the projections as shear factors
    basis = work[:3, :3].copy()
    scale = numpy.zeros((3, ))

    scale[0] = vector_norm(basis[0])
    basis[0] /= scale[0]

    shear_xy = numpy.dot(basis[0], basis[1])
    basis[1] -= basis[0] * shear_xy
    scale[1] = vector_norm(basis[1])
    basis[1] /= scale[1]
    shear_xy /= scale[1]

    shear_xz = numpy.dot(basis[0], basis[2])
    basis[2] -= basis[0] * shear_xz
    shear_yz = numpy.dot(basis[1], basis[2])
    basis[2] -= basis[1] * shear_yz
    scale[2] = vector_norm(basis[2])
    basis[2] /= scale[2]
    shear_xz /= scale[2]
    shear_yz /= scale[2]
    shear = [shear_xy, shear_xz, shear_yz]

    # negative triple product: flip signs of both scale and basis
    if numpy.dot(basis[0], numpy.cross(basis[1], basis[2])) < 0:
        numpy.negative(scale, out=scale)
        numpy.negative(basis, out=basis)

    angle_y = math.asin(-basis[0, 2])
    if math.cos(angle_y):
        angle_x = math.atan2(basis[1, 2], basis[2, 2])
        angle_z = math.atan2(basis[0, 1], basis[0, 0])
    else:
        # cos(angle_y) is exactly zero: the z angle is fixed at 0
        angle_x = math.atan2(-basis[2, 1], basis[1, 1])
        angle_z = 0.0
    angles = [angle_x, angle_y, angle_z]

    return scale, shear, angles, translate, perspective
821
+
822
+
823
def compose_matrix(scale=None, shear=None, angles=None, translate=None,
                   perspective=None):
    """Return transformation matrix from sequence of transformations.

    This is the inverse of the decompose_matrix function.

    Sequence of transformations:
        scale : vector of 3 scaling factors
        shear : list of shear factors for x-y, x-z, y-z axes
        angles : list of Euler angles about static x, y, z axes
        translate : translation vector along x, y, z axes
        perspective : perspective partition of matrix

    Arguments left as None are skipped. The product is formed in the
    order perspective, translate, angles, shear, scale and is finally
    divided by its [3, 3] element.

    >>> scale = numpy.random.random(3) - 0.5
    >>> shear = numpy.random.random(3) - 0.5
    >>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi)
    >>> trans = numpy.random.random(3) - 0.5
    >>> persp = numpy.random.random(4) - 0.5
    >>> M0 = compose_matrix(scale, shear, angles, trans, persp)
    >>> result = decompose_matrix(M0)
    >>> M1 = compose_matrix(*result)
    >>> is_same_transform(M0, M1)
    True

    """
    factors = []

    if perspective is not None:
        persp_part = numpy.identity(4)
        persp_part[3, :] = perspective[:4]
        factors.append(persp_part)

    if translate is not None:
        trans_part = numpy.identity(4)
        trans_part[:3, 3] = translate[:3]
        factors.append(trans_part)

    if angles is not None:
        factors.append(euler_matrix(angles[0], angles[1], angles[2], 'sxyz'))

    if shear is not None:
        shear_part = numpy.identity(4)
        shear_part[1, 2] = shear[2]
        shear_part[0, 2] = shear[1]
        shear_part[0, 1] = shear[0]
        factors.append(shear_part)

    if scale is not None:
        scale_part = numpy.identity(4)
        for axis in range(3):
            scale_part[axis, axis] = scale[axis]
        factors.append(scale_part)

    # right-multiply the factors onto the identity in the order collected
    composed = numpy.identity(4)
    for factor in factors:
        composed = numpy.dot(composed, factor)
    composed /= composed[3, 3]
    return composed
874
+
875
+
876
def orthogonalization_matrix(lengths, angles):
    """Return orthogonalization matrix for crystallographic cell coordinates.

    lengths : the three cell edge lengths
    angles : the three cell angles, expected in degrees

    The de-orthogonalization matrix is the inverse.

    >>> O = orthogonalization_matrix([10, 10, 10], [90, 90, 90])
    >>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10)
    True
    >>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7])
    >>> numpy.allclose(numpy.sum(O), 43.063229)
    True

    """
    a, b, c = lengths
    radians = numpy.radians(angles)
    sin_alpha, sin_beta, _ = numpy.sin(radians)
    cos_alpha, cos_beta, cos_gamma = numpy.cos(radians)
    co = (cos_alpha * cos_beta - cos_gamma) / (sin_alpha * sin_beta)

    # start from the identity and fill in the lower-triangular entries
    ortho = numpy.identity(4)
    ortho[0, 0] = a*sin_beta*math.sqrt(1.0-co*co)
    ortho[1, 0] = -a*sin_beta*co
    ortho[1, 1] = b*sin_alpha
    ortho[2, 0] = a*cos_beta
    ortho[2, 1] = b*cos_alpha
    ortho[2, 2] = c
    return ortho
901
+
902
+
903
def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
    r"""Return affine transform matrix to register two point sets.

    v0 and v1 are shape (ndims, \*) arrays of at least ndims non-homogeneous
    coordinates, where ndims is the dimensionality of the coordinate space.

    If shear is False, a similarity transformation matrix is returned.
    If also scale is False, a rigid/Euclidean transformation matrix
    is returned.

    By default the algorithm by Hartley and Zissermann [15] is used.
    If usesvd is True, similarity and Euclidean transformation matrices
    are calculated by minimizing the weighted sum of squared deviations
    (RMSD) according to the algorithm by Kabsch [8].
    Otherwise, and if ndims is 3, the quaternion based algorithm by Horn [9]
    is used, which is slower when using this Python implementation.

    The returned matrix performs rotation, translation and uniform scaling
    (if specified).

    Raise ValueError if ndims < 2, if fewer than ndims points are given,
    or if v0 and v1 differ in shape. The inputs are copied and never
    modified.

    >>> v0 = [[0, 1031, 1031, 0], [0, 0, 1600, 1600]]
    >>> v1 = [[675, 826, 826, 677], [55, 52, 281, 277]]
    >>> M = affine_matrix_from_points(v0, v1)
    >>> numpy.allclose(M, [[0.14549, 0.00062, 675.50008],
    ...                    [0.00048, 0.14094, 53.24971],
    ...                    [0.0, 0.0, 1.0]], atol=1e-4)
    True
    >>> T = translation_matrix(numpy.random.random(3)-0.5)
    >>> R = random_rotation_matrix(numpy.random.random(3))
    >>> S = scale_matrix(random.random())
    >>> M = concatenate_matrices(T, R, S)
    >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
    >>> v0[3] = 1
    >>> v1 = numpy.dot(M, v0)
    >>> v0[:3] += numpy.random.normal(0, 1e-8, 300).reshape(3, -1)
    >>> M = affine_matrix_from_points(v0[:3], v1[:3])
    >>> numpy.allclose(v1, numpy.dot(M, v0))
    True

    More examples in superimposition_matrix()

    """
    v0 = numpy.array(v0, dtype=numpy.float64, copy=True)
    v1 = numpy.array(v1, dtype=numpy.float64, copy=True)

    ndims = v0.shape[0]
    if ndims < 2 or v0.shape[1] < ndims or v0.shape != v1.shape:
        raise ValueError('input arrays are of wrong shape or type')

    # move centroids to origin
    t0 = -numpy.mean(v0, axis=1)
    M0 = numpy.identity(ndims+1)
    M0[:ndims, ndims] = t0
    v0 += t0.reshape(ndims, 1)
    t1 = -numpy.mean(v1, axis=1)
    M1 = numpy.identity(ndims+1)
    M1[:ndims, ndims] = t1
    v1 += t1.reshape(ndims, 1)

    if shear:
        # Affine transformation
        A = numpy.concatenate((v0, v1), axis=0)
        u, s, vh = numpy.linalg.svd(A.T)
        vh = vh[:ndims].T
        B = vh[:ndims]
        C = vh[ndims:2*ndims]
        t = numpy.dot(C, numpy.linalg.pinv(B))
        t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1)
        M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,)))
    elif usesvd or ndims != 3:
        # Rigid transformation via SVD of covariance matrix
        u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
        # rotation matrix from SVD orthonormal bases
        R = numpy.dot(u, vh)
        if numpy.linalg.det(R) < 0.0:
            # R does not constitute right handed system
            R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0)
            s[-1] *= -1.0
        # homogeneous transformation matrix
        M = numpy.identity(ndims+1)
        M[:ndims, :ndims] = R
    else:
        # Rigid transformation matrix via quaternion
        # compute symmetric matrix N (only the lower triangle is filled,
        # which is the part numpy.linalg.eigh reads by default)
        xx, yy, zz = numpy.sum(v0 * v1, axis=1)
        xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
        xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
        N = [[xx+yy+zz, 0.0, 0.0, 0.0],
             [yz-zy, xx-yy-zz, 0.0, 0.0],
             [zx-xz, xy+yx, yy-xx-zz, 0.0],
             [xy-yx, zx+xz, yz+zy, zz-xx-yy]]
        # quaternion: eigenvector corresponding to most positive eigenvalue
        w, V = numpy.linalg.eigh(N)
        q = V[:, numpy.argmax(w)]
        q /= vector_norm(q)  # unit quaternion
        # homogeneous transformation matrix
        M = quaternion_matrix(q)

    if scale and not shear:
        # Affine transformation; scale is ratio of RMS deviations from centroid
        v0 *= v0
        v1 *= v1
        M[:ndims, :ndims] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0))

    # move centroids back
    M = numpy.dot(numpy.linalg.inv(M1), numpy.dot(M, M0))
    M /= M[ndims, ndims]
    return M
1010
+
1011
+
1012
def superimposition_matrix(v0, v1, scale=False, usesvd=True):
    r"""Return matrix to transform given 3D point set into second point set.

    v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 points.
    Only the first three rows of each are passed on.

    The parameters scale and usesvd are explained in the more general
    affine_matrix_from_points function.

    The returned matrix is a similarity or Euclidean transformation matrix.
    This function has a fast C implementation in transformations.c.

    >>> v0 = numpy.random.rand(3, 10)
    >>> M = superimposition_matrix(v0, v0)
    >>> numpy.allclose(M, numpy.identity(4))
    True
    >>> R = random_rotation_matrix(numpy.random.random(3))
    >>> v0 = [[1,0,0], [0,1,0], [0,0,1], [1,1,1]]
    >>> v1 = numpy.dot(R, v0)
    >>> M = superimposition_matrix(v0, v1)
    >>> numpy.allclose(v1, numpy.dot(M, v0))
    True
    >>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
    >>> v0[3] = 1
    >>> v1 = numpy.dot(R, v0)
    >>> M = superimposition_matrix(v0, v1)
    >>> numpy.allclose(v1, numpy.dot(M, v0))
    True
    >>> S = scale_matrix(random.random())
    >>> T = translation_matrix(numpy.random.random(3)-0.5)
    >>> M = concatenate_matrices(T, R, S)
    >>> v1 = numpy.dot(M, v0)
    >>> v0[:3] += numpy.random.normal(0, 1e-9, 300).reshape(3, -1)
    >>> M = superimposition_matrix(v0, v1, scale=True)
    >>> numpy.allclose(v1, numpy.dot(M, v0))
    True
    >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
    >>> numpy.allclose(v1, numpy.dot(M, v0))
    True
    >>> v = numpy.empty((4, 100, 3))
    >>> v[:, :, 0] = v0
    >>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
    >>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0]))
    True

    """
    # numpy.asarray copies only when required; numpy.array(..., copy=False)
    # raises ValueError for list input under NumPy >= 2.0.
    v0 = numpy.asarray(v0, dtype=numpy.float64)[:3]
    v1 = numpy.asarray(v1, dtype=numpy.float64)[:3]
    return affine_matrix_from_points(v0, v1, shear=False,
                                     scale=scale, usesvd=usesvd)
1061
+
1062
+
1063
def euler_matrix(ai, aj, ak, axes='sxyz'):
    """Return homogeneous rotation matrix from Euler angles and axis sequence.

    ai, aj, ak : Euler's roll, pitch and yaw angles
    axes : One of 24 axis sequences as string or encoded tuple.
        String sequences are matched case-insensitively, as in
        euler_from_matrix and quaternion_from_euler.

    Raise KeyError if axes is neither a known string nor a known tuple.

    >>> R = euler_matrix(1, 2, 3, 'syxz')
    >>> numpy.allclose(numpy.sum(R[0]), -1.34786452)
    True
    >>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1))
    >>> numpy.allclose(numpy.sum(R[0]), -0.383436184)
    True
    >>> ai, aj, ak = (4*math.pi) * (numpy.random.random(3) - 0.5)
    >>> for axes in _AXES2TUPLE.keys():
    ...    R = euler_matrix(ai, aj, ak, axes)
    >>> for axes in _TUPLE2AXES.keys():
    ...    R = euler_matrix(ai, aj, ak, axes)

    """
    try:
        # .lower() raises AttributeError for tuple input, which falls
        # through to the encoded-tuple handling below
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # noqa: validation
        firstaxis, parity, repetition, frame = axes

    i = firstaxis
    j = _NEXT_AXIS[i+parity]
    k = _NEXT_AXIS[i-parity+1]

    if frame:
        ai, ak = ak, ai
    if parity:
        ai, aj, ak = -ai, -aj, -ak

    si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
    ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
    cc, cs = ci*ck, ci*sk
    sc, ss = si*ck, si*sk

    M = numpy.identity(4)
    if repetition:
        M[i, i] = cj
        M[i, j] = sj*si
        M[i, k] = sj*ci
        M[j, i] = sj*sk
        M[j, j] = -cj*ss+cc
        M[j, k] = -cj*cs-sc
        M[k, i] = -sj*ck
        M[k, j] = cj*sc+cs
        M[k, k] = cj*cc-ss
    else:
        M[i, i] = cj*ck
        M[i, j] = sj*sc-cs
        M[i, k] = sj*cc+ss
        M[j, i] = cj*sk
        M[j, j] = sj*ss+cc
        M[j, k] = sj*cs-sc
        M[k, i] = -sj
        M[k, j] = cj*si
        M[k, k] = cj*ci
    return M
1124
+
1125
+
1126
def euler_from_matrix(matrix, axes='sxyz'):
    """Return Euler angles from rotation matrix for specified axis sequence.

    matrix : array_like, only the upper-left 3x3 block is read
    axes : One of 24 axis sequences as string or encoded tuple

    Returns the tuple (ax, ay, az).

    Note that many Euler angle triplets can describe one matrix.

    >>> R0 = euler_matrix(1, 2, 3, 'syxz')
    >>> al, be, ga = euler_from_matrix(R0, 'syxz')
    >>> R1 = euler_matrix(al, be, ga, 'syxz')
    >>> numpy.allclose(R0, R1)
    True
    >>> angles = (4*math.pi) * (numpy.random.random(3) - 0.5)
    >>> for axes in _AXES2TUPLE.keys():
    ...    R0 = euler_matrix(axes=axes, *angles)
    ...    R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes))
    ...    if not numpy.allclose(R0, R1): print(axes, "failed")

    """
    try:
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # noqa: validation
        firstaxis, parity, repetition, frame = axes

    i = firstaxis
    j = _NEXT_AXIS[i+parity]
    k = _NEXT_AXIS[i-parity+1]

    # numpy.asarray copies only when required; numpy.array(..., copy=False)
    # raises ValueError for list input under NumPy >= 2.0.
    M = numpy.asarray(matrix, dtype=numpy.float64)[:3, :3]
    if repetition:
        sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
        if sy > _EPS:
            ax = math.atan2(M[i, j], M[i, k])
            ay = math.atan2(sy, M[i, i])
            az = math.atan2(M[j, i], -M[k, i])
        else:
            # sy is (near) zero: the third angle is fixed at 0
            ax = math.atan2(-M[j, k], M[j, j])
            ay = math.atan2(sy, M[i, i])
            az = 0.0
    else:
        cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
        if cy > _EPS:
            ax = math.atan2(M[k, j], M[k, k])
            ay = math.atan2(-M[k, i], cy)
            az = math.atan2(M[j, i], M[i, i])
        else:
            # cy is (near) zero: the third angle is fixed at 0
            ax = math.atan2(-M[j, k], M[j, j])
            ay = math.atan2(-M[k, i], cy)
            az = 0.0

    if parity:
        ax, ay, az = -ax, -ay, -az
    if frame:
        ax, az = az, ax
    return ax, ay, az
1182
+
1183
+
1184
def euler_from_quaternion(quaternion, axes='sxyz'):
    """Return Euler angles from quaternion for specified axis sequence.

    The quaternion is converted with quaternion_matrix and the result is
    handed to euler_from_matrix together with axes.

    >>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0])
    >>> numpy.allclose(angles, [0.123, 0, 0])
    True

    """
    rotation = quaternion_matrix(quaternion)
    return euler_from_matrix(rotation, axes)
1193
+
1194
+
1195
def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
    """Return quaternion from Euler angles and axis sequence.

    ai, aj, ak : Euler's roll, pitch and yaw angles
    axes : One of 24 axis sequences as string or encoded tuple

    Returns a length-4 numpy array.

    >>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
    >>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
    True

    """
    try:
        firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
    except (AttributeError, KeyError):
        _TUPLE2AXES[axes]  # noqa: validation
        firstaxis, parity, repetition, frame = axes

    # positions 1..3 of the quaternion that receive the axis components
    slot_i = firstaxis + 1
    slot_j = _NEXT_AXIS[slot_i+parity-1] + 1
    slot_k = _NEXT_AXIS[slot_i-parity] + 1

    if frame:
        ai, ak = ak, ai
    if parity:
        aj = -aj

    # the formulas below work on half angles
    ai /= 2.0
    aj /= 2.0
    ak /= 2.0
    ci, si = math.cos(ai), math.sin(ai)
    cj, sj = math.cos(aj), math.sin(aj)
    ck, sk = math.cos(ak), math.sin(ak)
    cc, cs = ci*ck, ci*sk
    sc, ss = si*ck, si*sk

    if repetition:
        components = (cj*(cc - ss), cj*(cs + sc), sj*(cc + ss), sj*(cs - sc))
    else:
        components = (cj*cc + sj*ss, cj*sc - sj*cs,
                      cj*ss + sj*cc, cj*cs - sj*sc)

    q = numpy.empty((4, ))
    q[0] = components[0]
    q[slot_i] = components[1]
    q[slot_j] = components[2]
    q[slot_k] = components[3]
    if parity:
        q[slot_j] *= -1.0

    return q
1250
+
1251
+
1252
def quaternion_about_axis(angle, axis):
    """Return quaternion for rotation about axis.

    angle : rotation angle (passed to math.sin / math.cos after halving)
    axis : sequence whose first three items give the rotation axis

    >>> q = quaternion_about_axis(0.123, [1, 0, 0])
    >>> numpy.allclose(q, [0.99810947, 0.06146124, 0, 0])
    True

    """
    quat = numpy.array([0.0, axis[0], axis[1], axis[2]])
    axis_length = vector_norm(quat)
    # a (near) zero-length axis is left unscaled
    if axis_length > _EPS:
        quat *= math.sin(angle/2.0) / axis_length
    quat[0] = math.cos(angle/2.0)
    return quat
1266
+
1267
+
1268
def quaternion_matrix(quaternion):
    """Return homogeneous rotation matrix from quaternion.

    A quaternion whose squared norm is below _EPS yields the identity.
    The input is copied and not modified.

    >>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
    >>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
    True
    >>> M = quaternion_matrix([1, 0, 0, 0])
    >>> numpy.allclose(M, numpy.identity(4))
    True
    >>> M = quaternion_matrix([0, 1, 0, 0])
    >>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
    True

    """
    quat = numpy.array(quaternion, dtype=numpy.float64, copy=True)
    squared_norm = numpy.dot(quat, quat)
    if squared_norm < _EPS:
        return numpy.identity(4)
    # scale so that the outer product carries the factor 2 / |q|^2
    quat *= math.sqrt(2.0 / squared_norm)
    p = numpy.outer(quat, quat)

    rotation = numpy.identity(4)
    rotation[0, 0] = 1.0-p[2, 2]-p[3, 3]
    rotation[0, 1] = p[1, 2]-p[3, 0]
    rotation[0, 2] = p[1, 3]+p[2, 0]
    rotation[1, 0] = p[1, 2]+p[3, 0]
    rotation[1, 1] = 1.0-p[1, 1]-p[3, 3]
    rotation[1, 2] = p[2, 3]-p[1, 0]
    rotation[2, 0] = p[1, 3]-p[2, 0]
    rotation[2, 1] = p[2, 3]+p[1, 0]
    rotation[2, 2] = 1.0-p[1, 1]-p[2, 2]
    return rotation
1293
+
1294
+
1295
def quaternion_from_matrix(matrix, isprecise=False):
    """Return quaternion from rotation matrix.

    matrix : array_like, at least 4x4 when isprecise is True (element
        [3, 3] is read); otherwise only the upper-left 3x3 block is read
    isprecise : bool
        If True, the input matrix is assumed to be a precise rotation
        matrix and a faster algorithm is used.

    Returns a length-4 numpy array whose first element is non-negative.

    >>> q = quaternion_from_matrix(numpy.identity(4), True)
    >>> numpy.allclose(q, [1, 0, 0, 0])
    True
    >>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
    >>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
    True
    >>> R = rotation_matrix(0.123, (1, 2, 3))
    >>> q = quaternion_from_matrix(R, True)
    >>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
    True
    >>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
    ...      [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
    >>> q = quaternion_from_matrix(R)
    >>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
    True
    >>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
    ...      [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
    >>> q = quaternion_from_matrix(R)
    >>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
    True
    >>> R = random_rotation_matrix()
    >>> q = quaternion_from_matrix(R)
    >>> is_same_transform(R, quaternion_matrix(q))
    True
    >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
    ...                    quaternion_from_matrix(R, isprecise=True))
    True
    >>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
    >>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
    ...                    quaternion_from_matrix(R, isprecise=True))
    True

    """
    # numpy.asarray copies only when required; numpy.array(..., copy=False)
    # raises ValueError for list input under NumPy >= 2.0.
    M = numpy.asarray(matrix, dtype=numpy.float64)[:4, :4]
    if isprecise:
        q = numpy.empty((4, ))
        t = numpy.trace(M)
        if t > M[3, 3]:
            q[0] = t
            q[3] = M[1, 0] - M[0, 1]
            q[2] = M[0, 2] - M[2, 0]
            q[1] = M[2, 1] - M[1, 2]
        else:
            # pick the largest diagonal element as pivot axis i
            i, j, k = 0, 1, 2
            if M[1, 1] > M[0, 0]:
                i, j, k = 1, 2, 0
            if M[2, 2] > M[i, i]:
                i, j, k = 2, 0, 1
            t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
            q[i] = t
            q[j] = M[i, j] + M[j, i]
            q[k] = M[k, i] + M[i, k]
            q[3] = M[k, j] - M[j, k]
            # the values were stored as (x, y, z, w); reorder to (w, x, y, z)
            q = q[[3, 0, 1, 2]]
        q *= 0.5 / math.sqrt(t * M[3, 3])
    else:
        m00 = M[0, 0]
        m01 = M[0, 1]
        m02 = M[0, 2]
        m10 = M[1, 0]
        m11 = M[1, 1]
        m12 = M[1, 2]
        m20 = M[2, 0]
        m21 = M[2, 1]
        m22 = M[2, 2]
        # symmetric matrix K (only the lower triangle is filled, which is
        # the part numpy.linalg.eigh reads by default)
        K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
                         [m01+m10, m11-m00-m22, 0.0, 0.0],
                         [m02+m20, m12+m21, m22-m00-m11, 0.0],
                         [m21-m12, m02-m20, m10-m01, m00+m11+m22]])
        K /= 3.0
        # quaternion is eigenvector of K that corresponds to largest eigenvalue
        w, V = numpy.linalg.eigh(K)
        q = V[[3, 0, 1, 2], numpy.argmax(w)]
    if q[0] < 0.0:
        numpy.negative(q, q)
    return q
1378
+
1379
+
1380
def quaternion_multiply(quaternion1, quaternion0):
    """Return multiplication of two quaternions.

    Both arguments are sequences of four numbers in (w, x, y, z) order;
    the result is a float64 numpy array in the same order.

    >>> q = quaternion_multiply([4, 1, -2, 3], [8, -5, 6, 7])
    >>> numpy.allclose(q, [28, -44, -14, 48])
    True

    """
    w1, x1, y1, z1 = quaternion1
    w0, x0, y0, z0 = quaternion0
    real = -x1*x0 - y1*y0 - z1*z0 + w1*w0
    imag_x = x1*w0 + y1*z0 - z1*y0 + w1*x0
    imag_y = -x1*z0 + y1*w0 + z1*x0 + w1*y0
    imag_z = x1*y0 - y1*x0 + z1*w0 + w1*z0
    return numpy.array([real, imag_x, imag_y, imag_z], dtype=numpy.float64)
1395
+
1396
+
1397
def quaternion_conjugate(quaternion):
    """Return conjugate of quaternion.

    The result is a new float64 array; the first element is kept and the
    remaining elements are negated. The input is not modified.

    >>> q0 = random_quaternion()
    >>> q1 = quaternion_conjugate(q0)
    >>> q1[0] == q0[0] and all(q1[1:] == -q0[1:])
    True

    """
    conjugate = numpy.array(quaternion, dtype=numpy.float64, copy=True)
    conjugate[1:] = -conjugate[1:]
    return conjugate
1409
+
1410
+
1411
def quaternion_inverse(quaternion):
    """Return inverse of quaternion.

    Computed as the conjugate divided by the squared norm. The input is
    copied and not modified.

    >>> q0 = random_quaternion()
    >>> q1 = quaternion_inverse(q0)
    >>> numpy.allclose(quaternion_multiply(q0, q1), [1, 0, 0, 0])
    True

    """
    conjugate = numpy.array(quaternion, dtype=numpy.float64, copy=True)
    conjugate[1:] = -conjugate[1:]
    squared_norm = numpy.dot(conjugate, conjugate)
    return conjugate / squared_norm
1423
+
1424
+
1425
def quaternion_real(quaternion):
    """Return the real part (first component) of a quaternion as a float.

    >>> quaternion_real([3, 0, 1, 2])
    3.0

    """
    real_part = quaternion[0]
    return float(real_part)
1433
+
1434
+
1435
def quaternion_imag(quaternion):
    """Return the imaginary part (components 1..3) as a new float64 array.

    >>> quaternion_imag([3, 0, 1, 2])
    array([ 0.,  1.,  2.])

    """
    vector_part = quaternion[1:4]
    return numpy.array(vector_part, dtype=numpy.float64, copy=True)
1443
+
1444
+
1445
def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
    """Return spherical linear interpolation between two quaternions.

    Both inputs are normalized first (copies are made by unit_vector).
    ``fraction`` 0 returns the first, 1 the second; ``spin`` adds multiples
    of pi to the interpolation angle; with ``shortestpath`` the second
    quaternion is negated when the two point into opposite hemispheres.

    >>> q0 = random_quaternion()
    >>> q1 = random_quaternion()
    >>> q = quaternion_slerp(q0, q1, 0)
    >>> numpy.allclose(q, q0)
    True
    >>> q = quaternion_slerp(q0, q1, 1, 1)
    >>> numpy.allclose(q, q1)
    True
    >>> q = quaternion_slerp(q0, q1, 0.5)
    >>> angle = math.acos(numpy.dot(q0, q))
    >>> numpy.allclose(2, math.acos(numpy.dot(q0, q1)) / angle) or \
        numpy.allclose(2, math.acos(-numpy.dot(q0, q1)) / angle)
    True

    """
    start = unit_vector(quat0[:4])
    end = unit_vector(quat1[:4])
    if fraction == 0.0:
        return start
    if fraction == 1.0:
        return end
    cosine = numpy.dot(start, end)
    if abs(abs(cosine) - 1.0) < _EPS:
        # (anti-)parallel inputs: nothing to interpolate
        return start
    if shortestpath and cosine < 0.0:
        # invert rotation
        cosine = -cosine
        numpy.negative(end, end)
    angle = math.acos(cosine) + spin * math.pi
    if abs(angle) < _EPS:
        return start
    inv_sin = 1.0 / math.sin(angle)
    # both operands are private copies, so in-place scaling is safe
    start *= math.sin((1.0 - fraction) * angle) * inv_sin
    end *= math.sin(fraction * angle) * inv_sin
    start += end
    return start
1484
+
1485
+
1486
def random_quaternion(rand=None):
    """Return uniform random unit quaternion.

    rand: array like or None
        Three independent random variables that are uniformly distributed
        between 0 and 1. Drawn with numpy.random.rand(3) when omitted.

    >>> q = random_quaternion()
    >>> numpy.allclose(1, vector_norm(q))
    True
    >>> q = random_quaternion(numpy.random.random(3))
    >>> len(q.shape), q.shape[0]==4
    (1, True)

    """
    if rand is None:
        rand = numpy.random.rand(3)
    else:
        assert len(rand) == 3
    radius_a = numpy.sqrt(1.0 - rand[0])
    radius_b = numpy.sqrt(rand[0])
    two_pi = math.pi * 2.0
    theta_a = two_pi * rand[1]
    theta_b = two_pi * rand[2]
    components = [
        numpy.cos(theta_b)*radius_b,
        numpy.sin(theta_a)*radius_a,
        numpy.cos(theta_a)*radius_a,
        numpy.sin(theta_b)*radius_b,
    ]
    return numpy.array(components)
1512
+
1513
+
1514
def random_rotation_matrix(rand=None):
    """Return uniform random rotation matrix.

    rand: array like
        Three independent random variables that are uniformly distributed
        between 0 and 1 for each returned quaternion. Forwarded unchanged to
        random_quaternion.

    >>> R = random_rotation_matrix()
    >>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4))
    True

    """
    quaternion = random_quaternion(rand)
    return quaternion_matrix(quaternion)
1527
+
1528
+
1529
class Arcball(object):
    """Virtual Trackball Control.

    Tracks a rotation as a quaternion that is updated from 2D window
    coordinates: ``down`` starts a drag, ``drag`` updates it, ``next``
    continues the last motion and ``matrix`` returns the current rotation.

    >>> ball = Arcball()
    >>> ball = Arcball(initial=numpy.identity(4))
    >>> ball.place([320, 320], 320)
    >>> ball.down([500, 250])
    >>> ball.drag([475, 275])
    >>> R = ball.matrix()
    >>> numpy.allclose(numpy.sum(R), 3.90583455)
    True
    >>> ball = Arcball(initial=[1, 0, 0, 0])
    >>> ball.place([320, 320], 320)
    >>> ball.setaxes([1, 1, 0], [-1, 1, 0])
    >>> ball.constrain = True
    >>> ball.down([400, 200])
    >>> ball.drag([200, 400])
    >>> R = ball.matrix()
    >>> numpy.allclose(numpy.sum(R), 0.2055924)
    True
    >>> ball.next()

    """

    def __init__(self, initial=None):
        """Initialize virtual trackball control.

        initial : quaternion or rotation matrix
            Either a length-4 sequence (normalized here) or a 4x4 matrix.
            Defaults to the identity quaternion [1, 0, 0, 0].

        Raises ValueError for any other shape.

        """
        self._axis = None
        self._axes = None
        self._radius = 1.0
        self._center = [0.0, 0.0]
        self._vdown = numpy.array([0.0, 0.0, 1.0])
        self._constrain = False
        if initial is None:
            self._qdown = numpy.array([1.0, 0.0, 0.0, 0.0])
        else:
            initial = numpy.array(initial, dtype=numpy.float64)
            if initial.shape == (4, 4):
                self._qdown = quaternion_from_matrix(initial)
            elif initial.shape == (4, ):
                initial /= vector_norm(initial)
                self._qdown = initial
            else:
                raise ValueError("initial not a quaternion or matrix")
        self._qnow = self._qpre = self._qdown

    def place(self, center, radius):
        """Place Arcball, e.g. when window size changes.

        center : sequence[2]
            Window coordinates of trackball center.
        radius : float
            Radius of trackball in window coordinates.

        """
        self._radius = float(radius)
        self._center[0] = center[0]
        self._center[1] = center[1]

    def setaxes(self, *axes):
        """Set axes to constrain rotations.

        Each positional argument is one axis; it is stored normalized.
        Calling without arguments clears the constraint axes.

        """
        # *axes is always a tuple, so test for emptiness rather than None;
        # an empty list would make down() pick "no axis" and produce NaNs.
        if not axes:
            self._axes = None
        else:
            self._axes = [unit_vector(axis) for axis in axes]

    @property
    def constrain(self):
        """Return state of constrain to axis mode."""
        return self._constrain

    @constrain.setter
    def constrain(self, value):
        """Set state of constrain to axis mode."""
        self._constrain = bool(value)

    def down(self, point):
        """Set initial cursor window coordinates and pick constrain-axis."""
        self._vdown = arcball_map_to_sphere(point, self._center, self._radius)
        self._qdown = self._qpre = self._qnow
        if self._constrain and self._axes is not None:
            self._axis = arcball_nearest_axis(self._vdown, self._axes)
            self._vdown = arcball_constrain_to_axis(self._vdown, self._axis)
        else:
            self._axis = None

    def drag(self, point):
        """Update current cursor window coordinates."""
        vnow = arcball_map_to_sphere(point, self._center, self._radius)
        if self._axis is not None:
            vnow = arcball_constrain_to_axis(vnow, self._axis)
        self._qpre = self._qnow
        t = numpy.cross(self._vdown, vnow)
        if numpy.dot(t, t) < _EPS:
            # no measurable movement since down(): keep the starting rotation
            self._qnow = self._qdown
        else:
            q = [numpy.dot(self._vdown, vnow), t[0], t[1], t[2]]
            self._qnow = quaternion_multiply(q, self._qdown)

    def next(self, acceleration=0.0):
        """Continue rotation in direction of last drag."""
        # NOTE(review): False is positional and therefore binds to the
        # `spin` parameter of quaternion_slerp, not `shortestpath`; kept
        # as-is to preserve existing behaviour -- confirm intent.
        q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
        self._qpre, self._qnow = self._qnow, q

    def matrix(self):
        """Return homogeneous rotation matrix."""
        return quaternion_matrix(self._qnow)
1639
+
1640
+
1641
def arcball_map_to_sphere(point, center, radius):
    """Return unit sphere coordinates from window coordinates.

    The x offset is taken as point - center and the y offset as
    center - point (window y is flipped), both divided by radius. Points
    outside the unit disc are projected onto its rim (z = 0).
    """
    dx = (point[0] - center[0]) / radius
    dy = (center[1] - point[1]) / radius
    dist_sq = dx*dx + dy*dy
    if dist_sq > 1.0:
        # position outside of sphere
        dist = math.sqrt(dist_sq)
        on_rim = [dx/dist, dy/dist, 0.0]
        return numpy.array(on_rim)
    on_sphere = [dx, dy, math.sqrt(1.0 - dist_sq)]
    return numpy.array(on_sphere)
1652
+
1653
+
1654
def arcball_constrain_to_axis(point, axis):
    """Return sphere point perpendicular to axis.

    The point is projected onto the plane orthogonal to ``axis`` and
    normalized, flipped if needed so that its z component is not negative.
    If the projection has (near) zero length a fixed perpendicular
    direction is returned instead.
    """
    projected = numpy.array(point, dtype=numpy.float64, copy=True)
    normal = numpy.array(axis, dtype=numpy.float64, copy=True)
    projected -= normal * numpy.dot(normal, projected)  # on plane
    length = vector_norm(projected)
    if not length > _EPS:
        # degenerate projection: pick any direction perpendicular to axis
        if normal[2] == 1.0:
            return numpy.array([1.0, 0.0, 0.0])
        return unit_vector([-normal[1], normal[0], 0.0])
    if projected[2] < 0.0:
        numpy.negative(projected, projected)
    projected /= length
    return projected
1668
+
1669
+
1670
def arcball_nearest_axis(point, axes):
    """Return axis, which arc is nearest to point.

    For every candidate axis the point is constrained to that axis' plane
    and compared with the original point via a dot product; the axis with
    the largest value wins. Returns None when ``axes`` is empty or no
    candidate scores above -1.
    """
    # asarray avoids a copy when possible; numpy.array(..., copy=False)
    # raises on NumPy 2 whenever a copy is required (e.g. list input).
    point = numpy.asarray(point, dtype=numpy.float64)
    nearest = None
    mx = -1.0
    for axis in axes:
        t = numpy.dot(arcball_constrain_to_axis(point, axis), point)
        if t > mx:
            nearest = axis
            mx = t
    return nearest
1681
+
1682
+
1683
# tolerance used when testing whether a number is close to zero
_EPS = numpy.finfo(float).eps * 4.0

# successor axis lookup used for Euler angle axis sequences
_NEXT_AXIS = [1, 2, 0, 1]

# axes string -> (inner axis, parity, repetition, frame)
_AXES2TUPLE = {
    'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
    'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
    'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
    'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
    'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
    'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
    'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
    'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1),
}

# reverse lookup: (inner axis, parity, repetition, frame) -> axes string
_TUPLE2AXES = {code: label for label, code in _AXES2TUPLE.items()}
1701
+
1702
+
1703
def vector_norm(data, axis=None, out=None):
    """Return length, i.e. Euclidean norm, of ndarray along axis.

    The input is copied to float64. Without ``out`` the norm is returned
    (a Python float for 1-D input, otherwise an array of at least one
    dimension). With ``out`` the result is written there and None is
    returned.

    >>> v = numpy.random.random(3)
    >>> n = vector_norm(v)
    >>> numpy.allclose(n, numpy.linalg.norm(v))
    True
    >>> v = numpy.random.rand(6, 5, 3)
    >>> n = vector_norm(v, axis=-1)
    >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2)))
    True
    >>> n = vector_norm(v, axis=1)
    >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
    True
    >>> v = numpy.random.rand(5, 4, 3)
    >>> n = numpy.empty((5, 3))
    >>> vector_norm(v, axis=1, out=n)
    >>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
    True
    >>> vector_norm([])
    0.0
    >>> vector_norm([1])
    1.0

    """
    values = numpy.array(data, dtype=numpy.float64, copy=True)
    if out is not None:
        # caller supplied the destination: square, reduce and root in place
        values *= values
        numpy.sum(values, axis=axis, out=out)
        numpy.sqrt(out, out)
        return None
    if values.ndim == 1:
        return math.sqrt(numpy.dot(values, values))
    values *= values
    result = numpy.atleast_1d(numpy.sum(values, axis=axis))
    numpy.sqrt(result, result)
    return result
1740
+
1741
+
1742
def unit_vector(data, axis=None, out=None):
    """Return ndarray normalized by length, i.e. Euclidean norm, along axis.

    Without ``out`` a normalized float64 copy of ``data`` is returned.
    With ``out`` the normalized values are written into ``out`` (which may
    be ``data`` itself for in-place use) and None is returned.

    >>> v0 = numpy.random.random(3)
    >>> v1 = unit_vector(v0)
    >>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
    True
    >>> v0 = numpy.random.rand(5, 4, 3)
    >>> v1 = unit_vector(v0, axis=-1)
    >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
    >>> numpy.allclose(v1, v2)
    True
    >>> v1 = unit_vector(v0, axis=1)
    >>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
    >>> numpy.allclose(v1, v2)
    True
    >>> v1 = numpy.empty((5, 4, 3))
    >>> unit_vector(v0, axis=1, out=v1)
    >>> numpy.allclose(v1, v2)
    True
    >>> list(unit_vector([]))
    []
    >>> list(unit_vector([1]))
    [1.0]

    """
    if out is None:
        data = numpy.array(data, dtype=numpy.float64, copy=True)
        if data.ndim == 1:
            data /= math.sqrt(numpy.dot(data, data))
            return data
    else:
        if out is not data:
            # asarray instead of array(..., copy=False): NumPy 2 raises
            # ValueError for copy=False when a copy cannot be avoided
            # (e.g. when data is a list).
            out[:] = numpy.asarray(data)
        data = out
    length = numpy.atleast_1d(numpy.sum(data*data, axis))
    numpy.sqrt(length, length)
    if axis is not None:
        # restore the reduced axis so the division broadcasts
        length = numpy.expand_dims(length, axis)
    data /= length
    if out is None:
        return data
    return None
1784
+
1785
+
1786
def random_vector(size):
    """Return array of random doubles in the half-open interval [0.0, 1.0).

    Thin wrapper around numpy.random.random(size).

    >>> v = random_vector(10000)
    >>> numpy.all(v >= 0) and numpy.all(v < 1)
    True
    >>> v0 = random_vector(10)
    >>> v1 = random_vector(10)
    >>> numpy.any(v0 == v1)
    False

    """
    samples = numpy.random.random(size)
    return samples
1799
+
1800
+
1801
def vector_product(v0, v1, axis=0):
    """Return vector perpendicular to vectors.

    Computes the cross product of v0 and v1 with numpy.cross, where
    ``axis`` names the axis that holds the vector components.

    >>> v = vector_product([2, 0, 0], [0, 3, 0])
    >>> numpy.allclose(v, [0, 0, 6])
    True
    >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
    >>> v1 = [[3], [0], [0]]
    >>> v = vector_product(v0, v1)
    >>> numpy.allclose(v, [[0, 0, 0, 0], [0, 0, 6, 6], [0, -6, 0, -6]])
    True
    >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
    >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
    >>> v = vector_product(v0, v1, axis=1)
    >>> numpy.allclose(v, [[0, 0, 6], [0, -6, 0], [6, 0, 0], [0, -6, 6]])
    True

    """
    perpendicular = numpy.cross(v0, v1, axis=axis)
    return perpendicular
1820
+
1821
+
1822
def angle_between_vectors(v0, v1, directed=True, axis=0):
    """Return angle between vectors.

    If directed is False, the input vectors are interpreted as undirected axes,
    i.e. the maximum angle is pi/2.

    ``axis`` names the axis holding the vector components. The cosine is
    clipped to [-1, 1] before arccos to absorb rounding errors.

    >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3])
    >>> numpy.allclose(a, math.pi)
    True
    >>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3], directed=False)
    >>> numpy.allclose(a, 0)
    True
    >>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
    >>> v1 = [[3], [0], [0]]
    >>> a = angle_between_vectors(v0, v1)
    >>> numpy.allclose(a, [0, 1.5708, 1.5708, 0.95532])
    True
    >>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
    >>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
    >>> a = angle_between_vectors(v0, v1, axis=1)
    >>> numpy.allclose(a, [1.5708, 1.5708, 1.5708, 0.95532])
    True

    """
    # asarray instead of array(..., copy=False): on NumPy 2 the latter
    # raises ValueError whenever a copy is needed (list or int input).
    v0 = numpy.asarray(v0, dtype=numpy.float64)
    v1 = numpy.asarray(v1, dtype=numpy.float64)
    dot = numpy.sum(v0 * v1, axis=axis)
    dot /= vector_norm(v0, axis=axis) * vector_norm(v1, axis=axis)
    dot = numpy.clip(dot, -1.0, 1.0)
    return numpy.arccos(dot if directed else numpy.fabs(dot))
1852
+
1853
+
1854
def inverse_matrix(matrix):
    """Return inverse of square transformation matrix.

    Delegates to numpy.linalg.inv.

    >>> M0 = random_rotation_matrix()
    >>> M1 = inverse_matrix(M0.T)
    >>> numpy.allclose(M1, numpy.linalg.inv(M0.T))
    True
    >>> for size in range(1, 7):
    ...     M0 = numpy.random.rand(size, size)
    ...     M1 = inverse_matrix(M0)
    ...     if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size)

    """
    inverted = numpy.linalg.inv(matrix)
    return inverted
1868
+
1869
+
1870
def concatenate_matrices(*matrices):
    """Return concatenation of series of transformation matrices.

    Starts from the 4x4 identity and right-multiplies each argument in
    order; with no arguments the identity is returned.

    >>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5
    >>> numpy.allclose(M, concatenate_matrices(M))
    True
    >>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T))
    True

    """
    product = numpy.identity(4)
    for factor in matrices:
        product = numpy.dot(product, factor)
    return product
1884
+
1885
+
1886
def is_same_transform(matrix0, matrix1):
    """Return True if two matrices perform same transformation.

    Both matrices are copied to float64 and divided by their [3, 3]
    element before being compared with numpy.allclose, so matrices that
    differ only by a homogeneous scale compare equal.

    >>> is_same_transform(numpy.identity(4), numpy.identity(4))
    True
    >>> is_same_transform(numpy.identity(4), random_rotation_matrix())
    False

    """
    normalized = []
    for matrix in (matrix0, matrix1):
        scaled = numpy.array(matrix, dtype=numpy.float64, copy=True)
        scaled /= scaled[3, 3]
        normalized.append(scaled)
    return numpy.allclose(normalized[0], normalized[1])
1900
+
1901
+
1902
def is_same_quaternion(q0, q1):
    """Return True if two quaternions are equal.

    q and -q are treated as equal; comparison uses numpy.allclose.
    """
    first = numpy.array(q0)
    second = numpy.array(q1)
    if numpy.allclose(first, second):
        return True
    return numpy.allclose(first, -second)
1907
+
1908
+
1909
def _import_module(name, package=None, warn=True, postfix='_py', ignore='_'):
    """Try import all public attributes from module into global namespace.

    Existing attributes with name clashes are kept under their old name
    plus ``postfix``. Attributes starting with ``ignore`` (underscore by
    default) are skipped.

    Return True on successful import; None if the import failed (a warning
    is issued in that case when ``warn`` is set).

    """
    import warnings
    from importlib import import_module
    try:
        if package:
            module = import_module('.' + name, package=package)
        else:
            module = import_module(name)
    except ImportError as err:
        if warn:
            warnings.warn(str(err))
        return None
    for attr in dir(module):
        if ignore and attr.startswith(ignore):
            continue
        if postfix:
            if attr in globals():
                # keep the current implementation reachable as <attr><postfix>
                globals()[attr + postfix] = globals()[attr]
            elif warn:
                warnings.warn('no Python implementation of ' + attr)
        globals()[attr] = getattr(module, attr)
    return True
1939
+
1940
+
1941
# Try to pull the public names of an optional '_transformations' module into
# this namespace; the functions defined above stay reachable with a '_py'
# postfix. warn=False keeps a missing module silent.
_import_module('_transformations', __package__, warn=False)


if __name__ == '__main__':
    # Run the docstring examples of this module as tests.
    import doctest
    import random  # noqa: used in doctests
    try:
        # the `legacy` keyword is rejected with TypeError where unsupported
        numpy.set_printoptions(suppress=True, precision=5, legacy='1.13')
    except TypeError:
        numpy.set_printoptions(suppress=True, precision=5)
    doctest.testmod()
third_party/COTR/COTR/utils/constants.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Default floating point precision, given as a dtype-style name.
DEFAULT_PRECISION = 'float32'
# NOTE(review): presumably the maximum image/patch side length in pixels -- confirm against callers.
MAX_SIZE = 256
# NOTE(review): presumably the minimum overlap ratio for a nearest neighbour to count as valid -- confirm.
VALID_NN_OVERLAPPING_THRESH = 0.1
third_party/COTR/COTR/utils/debug_utils.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
def embed_breakpoint(debug_info='', terminate=True):
    """Announce a breakpoint and return code that opens an IPython shell.

    Prints a banner (plus ``debug_info`` when given) and returns a string
    of Python statements that import IPython and matplotlib and call
    ``IPython.embed()``. With ``terminate`` the string ends with a failing
    assert. The string is only returned, not executed -- presumably the
    caller passes it to ``exec``.
    """
    print('\nyou are inside a break point')
    if debug_info:
        print('debug info: {0}'.format(debug_info))
    print('')
    statements = [
        'import IPython\n',
        'import matplotlib.pyplot as plt\n',
        'IPython.embed()\n',
    ]
    if terminate:
        statements.append('assert 0, \'force termination\'\n')
    return ''.join(statements)
third_party/COTR/COTR/utils/utils.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import smtplib
3
+ import ssl
4
+ from collections import namedtuple
5
+
6
+ from COTR.utils import debug_utils
7
+
8
+ import numpy as np
9
+ import torch
10
+ import cv2
11
+ import matplotlib.pyplot as plt
12
+ import PIL
13
+
14
+
15
# ImagePatch fields:
#   patch: patch content, np array or None
#   x:     left bound in original resolution
#   y:     upper bound in original resolution
#   w:     width of patch
#   h:     height of patch
#   ow:    width of original resolution
#   oh:    height of original resolution
ImagePatch = namedtuple(
    'ImagePatch',
    ['patch', 'x', 'y', 'w', 'h', 'ow', 'oh'],
)
# Lightweight records for 3D points and their 2D observations.
Point3D = namedtuple(
    "Point3D",
    ["id", "arr_idx", "image_ids"],
)
Point2D = namedtuple(
    "Point2D",
    ["id_3d", "xy"],
)
27
+
28
+
29
class CropCamConfig():
    """Plain record describing a crop of an image and its resized output."""

    def __init__(self, x, y, w, h, out_w, out_h, orig_w, orig_h):
        '''
        xy: left upper corner
        w, h: crop size
        out_w, out_h: size the crop is resized to
        orig_w, orig_h: size of the original image

        No validation is performed on the values.
        '''
        self.x, self.y = x, y
        self.w, self.h = w, h
        self.out_w, self.out_h = out_w, out_h
        self.orig_w, self.orig_h = orig_w, orig_h

    def __str__(self):
        """Return a four-line human readable summary of the configuration."""
        lines = [
            f'original image size(h,w): [{self.orig_h}, {self.orig_w}]',
            f'crop at(x,y): [{self.x}, {self.y}]',
            f'crop size(h,w): [{self.h}, {self.w}]',
            f'resize crop to(h,w): [{self.out_h}, {self.out_w}]',
        ]
        return '\n'.join(lines)
55
+
56
+
57
def fix_randomness(seed=42):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    :param seed: integer seed applied to ``random``, ``numpy.random`` and
        ``torch``.
    """
    # seed every generator used by this code base
    for seeder in (random.seed, torch.manual_seed, np.random.seed):
        seeder(seed)
    # trade cuDNN autotuning for reproducible kernels
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
63
+
64
+
65
def worker_init_fn(worker_id):
    """Re-seed NumPy's global RNG with a per-worker seed.

    The new seed is the first word of the current NumPy RNG state plus
    ``worker_id``, so workers starting from the same state end up with
    different streams. Presumably meant to be passed as ``worker_init_fn``
    to a torch DataLoader.

    :param worker_id: integer id of the worker.
    """
    base_seed = int(np.random.get_state()[1][0])
    # np.random.seed only accepts values in [0, 2**32 - 1]; wrap so that a
    # large state word plus the worker id can neither overflow nor raise.
    np.random.seed((base_seed + worker_id) % (2 ** 32))
67
+
68
+
69
def float_image_resize(img, shape, interp=PIL.Image.BILINEAR):
    """Resize an image channel by channel through PIL.

    :param img: 2-D (H, W) or 3-D (H, W, C) numpy array.
    :param shape: target (height, width); compared with ``==`` against the
        resized array's shape, so it has to be a tuple.
    :param interp: PIL resampling filter.
    :return: resized array; 2-D input yields 2-D output.
    """
    was_2d = len(img.shape) == 2
    if was_2d:
        img = img[..., None]
    pil_size = shape[::-1]  # PIL expects (width, height)
    resized = []
    for channel in img.transpose(2, 0, 1):
        plane = np.array(PIL.Image.fromarray(channel).resize(pil_size, resample=interp))
        assert plane.shape[:2] == shape
        resized.append(plane)
    stacked = np.stack(resized, axis=-1)
    if was_2d:
        return stacked[..., 0]
    return stacked
84
+
85
+
86
def is_nan(x):
    """
    get mask of nan values.

    Relies on NaN being the only value that compares unequal to itself.
    :param x: torch or numpy var.
    :return: a N-D array of bool. True -> nan, False -> ok.
    """
    return x != x
93
+
94
+
95
def has_nan(x) -> bool:
    """
    check whether x contains nan.

    ``None`` is treated as containing no nan.
    :param x: torch or numpy var.
    :return: single bool, True -> x containing nan, False -> ok.
    """
    return False if x is None else is_nan(x).any()
104
+
105
+
106
def confirm(question='OK to continue?'):
    """
    Ask user to enter Y or N (case-insensitive).

    The prompt is repeated until one of the two answers is given.
    :return: True if the answer is Y.
    :rtype: bool
    """
    while True:
        reply = input(question + ' [y/n] ').lower()
        if reply in ["y", "n"]:
            return reply == "y"
116
+
117
+
118
def print_notification(content_list, notification_type='NOTIFICATION'):
    """Print a banner-framed notification to stdout.

    :param content_list: iterable of items, each printed on its own line.
    :param notification_type: label shown in the header line.
    """
    header = '---------------------- {0} ----------------------'.format(notification_type)
    footer = '----------------------------------------------------'
    print(header)
    print()
    for entry in content_list:
        print(entry)
    print()
    print(footer)
125
+
126
+
127
def torch_img_to_np_img(torch_img):
    '''convert a torch image to matplotlib-able numpy image
    torch use Channels x Height x Width
    numpy use Height x Width x Channels

    Accepted inputs: 4-D (N, C, H, W) or 3-D (C, H, W) tensors with C equal
    to 1 or 3, and 2-D tensors (returned without transposition).
    Arguments:
        torch_img {torch.Tensor} -- image tensor; it is detached and moved
            to CPU before conversion
    Raises:
        ValueError -- for any other shape
    '''
    assert isinstance(torch_img, torch.Tensor), 'cannot process data type: {0}'.format(type(torch_img))
    dims = torch_img.shape
    rank = len(dims)
    if rank == 4 and dims[1] in (3, 1):
        order = (0, 2, 3, 1)
    elif rank == 3 and dims[0] in (3, 1):
        order = (1, 2, 0)
    elif rank == 2:
        return torch_img.detach().cpu().numpy()
    else:
        raise ValueError('cannot process this image')
    return np.transpose(torch_img.detach().cpu().numpy(), order)
143
+
144
+
145
def np_img_to_torch_img(np_img):
    """convert a numpy image to torch image
    numpy use Height x Width x Channels
    torch use Channels x Height x Width

    Accepted inputs: 4-D (N, H, W, C) or 3-D (H, W, C) arrays with C equal
    to 1 or 3, and 2-D arrays (converted without transposition).

    Arguments:
        np_img {np.ndarray} -- image array
    Raises:
        ValueError -- for any other shape
    """
    assert isinstance(np_img, np.ndarray), 'cannot process data type: {0}'.format(type(np_img))
    dims = np_img.shape
    rank = len(dims)
    if rank == 4 and dims[3] in (3, 1):
        order = (0, 3, 1, 2)
    elif rank == 3 and dims[2] in (3, 1):
        order = (2, 0, 1)
    elif rank == 2:
        return torch.from_numpy(np_img)
    else:
        raise ValueError('cannot process this image with shape: {0}'.format(np_img.shape))
    return torch.from_numpy(np.transpose(np_img, order))
162
+
163
+
164
def safe_load_weights(model, saved_weights):
    """Load a state dict into ``model``, tolerating common key mismatches.

    Attempts, in order: the weights as given; with every 'module.'
    occurrence removed from the keys; with 'module.' prepended to the keys;
    and finally a partial load of the entries whose key and shape match the
    model (a warning lists the keys that were not covered). If the partial
    load fails too, the error is printed and the process exits.

    :param model: object exposing ``load_state_dict`` / ``state_dict``.
    :param saved_weights: mapping from parameter name to tensor.
    """
    def _unchanged(weights):
        return weights

    def _without_prefix(weights):
        return {k.replace('module.', ''): v for k, v in weights.items()}

    def _with_prefix(weights):
        return {'module.' + k: v for k, v in weights.items()}

    for remap in (_unchanged, _without_prefix, _with_prefix):
        try:
            model.load_state_dict(remap(saved_weights))
        except RuntimeError:
            continue
        break
    else:
        # none of the key layouts matched completely: load what fits
        try:
            pretrained_dict = saved_weights
            model_dict = model.state_dict()
            pretrained_dict = {k: v for k, v in pretrained_dict.items() if ((k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape))}
            assert len(pretrained_dict) != 0
            model_dict.update(pretrained_dict)
            model.load_state_dict(model_dict)
            non_match_keys = set(model.state_dict().keys()) - set(pretrained_dict.keys())
            notification = [
                'pretrained weights PARTIALLY loaded, following are missing:',
                str(non_match_keys),
            ]
            print_notification(notification, 'WARNING')
        except Exception as e:
            print(f'pretrained weights loading failed {e}')
            exit()
    print('weights safely loaded')
194
+
195
+
196
def visualize_corrs(img1, img2, corrs, mask=None):
    """Show two images stacked vertically with their correspondences drawn.

    :param img1: first image array; placed on top. The canvas is allocated
        with 3 channels, so presumably H x W x 3 -- confirm.
    :param img2: second image array; placed below img1.
    :param corrs: 2-D array with four columns: (x, y) in img1 followed by
        (x, y) in img2, in original image coordinates.
    :param mask: boolean array selecting inliers; rows where it is False are
        drawn as outliers. Defaults to all inliers.
    :return: None; the figure is displayed with ``plt.show()``.
    """
    if mask is None:
        mask = np.ones(len(corrs)).astype(bool)

    # scale the narrower image up so both images share the same width
    scale1 = 1.0
    scale2 = 1.0
    if img1.shape[1] > img2.shape[1]:
        scale2 = img1.shape[1] / img2.shape[1]
        w = img1.shape[1]
    else:
        scale1 = img2.shape[1] / img1.shape[1]
        w = img2.shape[1]
    # Resize if too big
    max_w = 400
    if w > max_w:
        scale1 *= max_w / w
        scale2 *= max_w / w
    img1 = cv2.resize(img1, (0, 0), fx=scale1, fy=scale1)
    img2 = cv2.resize(img2, (0, 0), fx=scale2, fy=scale2)

    x1, x2 = corrs[:, :2], corrs[:, 2:]
    h1, w1 = img1.shape[:2]
    h2, w2 = img2.shape[:2]
    # canvas holding img1 on top of img2
    img = np.zeros((h1 + h2, max(w1, w2), 3), dtype=img1.dtype)
    img[:h1, :w1] = img1
    img[h1:, :w2] = img2
    # Scale the keypoints to the resized image coordinates
    x1 = x1 * scale1
    x2 = x2 * scale2
    # shift the second image's points down by the height of the first image
    x2p = x2 + np.array([[0, h1]])
    fig = plt.figure(frameon=False)
    # NOTE: `fig` is rebound to the AxesImage returned by imshow
    fig = plt.imshow(img)

    # line colours: index 0 for inliers, index 1 for outliers
    cols = [
        [0.0, 0.67, 0.0],
        [0.9, 0.1, 0.1],
    ]
    lw = .5
    alpha = 1

    # Draw outliers
    _x1 = x1[~mask]
    _x2p = x2p[~mask]
    # one column per correspondence: row 0 = img1 endpoint, row 1 = img2 endpoint
    xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T
    ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T
    plt.plot(
        xs, ys,
        alpha=alpha,
        linestyle="-",
        linewidth=lw,
        aa=False,
        color=cols[1],
    )


    # Draw Inliers
    _x1 = x1[mask]
    _x2p = x2p[mask]
    xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T
    ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T
    plt.plot(
        xs, ys,
        alpha=alpha,
        linestyle="-",
        linewidth=lw,
        aa=False,
        color=cols[0],
    )
    # mark the endpoints of the inlier correspondences
    plt.scatter(xs, ys)

    # hide the axes so that only the image and the overlays are visible
    fig.axes.get_xaxis().set_visible(False)
    fig.axes.get_yaxis().set_visible(False)
    ax = plt.gca()
    ax.set_axis_off()
    plt.show()
third_party/COTR/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
third_party/COTR/demo_face.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ COTR demo for human face
3
+ We use an off-the-shelf face landmarks detector: https://github.com/1adrianb/face-alignment
4
+ '''
5
+ import argparse
6
+ import os
7
+ import time
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import torch
12
+ import imageio
13
+ import matplotlib.pyplot as plt
14
+
15
+ from COTR.utils import utils, debug_utils
16
+ from COTR.models import build_model
17
+ from COTR.options.options import *
18
+ from COTR.options.options_utils import *
19
+ from COTR.inference.inference_helper import triangulate_corr
20
+ from COTR.inference.sparse_engine import SparseEngine
21
+
22
+ utils.fix_randomness(0)
23
+ torch.set_grad_enabled(False)
24
+
25
+
26
def main(opt):
    '''
    Run the face demo: query COTR with the landmarks stored for the first
    face image and plot the predicted locations on the second face image.

    opt is the parsed option namespace; this function reads
    opt.load_weights_path and hands opt to build_model.
    '''
    net = build_model(opt).cuda()
    checkpoint = torch.load(opt.load_weights_path, map_location='cpu')
    utils.safe_load_weights(net, checkpoint['model_state_dict'])
    net = net.eval()

    img_a = imageio.imread('./sample_data/imgs/face_1.png', pilmode='RGB')
    img_b = imageio.imread('./sample_data/imgs/face_2.png', pilmode='RGB')
    # Only the first entry of the landmark file is used as the query set.
    queries = np.load('./sample_data/face_landmarks.npy')[0]

    engine = SparseEngine(net, 32, mode='stretching')
    corrs = engine.cotr_corr_multiscale(
        img_a,
        img_b,
        np.linspace(0.5, 0.0625, 4),
        1,
        queries_a=queries,
        force=False,
    )

    # Left panel: reference image with the queries; right panel: target image
    # with columns 2: of the returned correspondences.
    _, axarr = plt.subplots(1, 2)
    panels = (
        (axarr[0], img_a, queries, 'Reference Face'),
        (axarr[1], img_b, corrs[:, 2:], 'Target Face'),
    )
    for axis, image, points, title in panels:
        axis.imshow(image)
        axis.scatter(*points.T, s=1)
        axis.title.set_text(title)
        axis.axis('off')
    plt.show()
50
+
51
+
52
if __name__ == "__main__":
    # sys.argv is used below; import it explicitly instead of relying on a
    # wildcard import from the options modules to leak the name.
    import sys

    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')

    opt = parser.parse_args()
    # Record the exact command line on the options object.
    opt.command = ' '.join(sys.argv)

    # The feed-forward width is looked up from the selected backbone layer.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    elif getattr(opt, 'load_weights_path', None) is None:
        # main() reads opt.load_weights_path unconditionally; fail here with a
        # usage message rather than with an AttributeError after model setup.
        parser.error('--load_weights is required: main() loads a checkpoint from opt.load_weights_path')
    print_opt(opt)
    main(opt)
third_party/COTR/demo_guided_matching.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Feature-free COTR guided matching for keypoints.
3
+ We use DISK(https://github.com/cvlab-epfl/disk) keypoints location.
4
+ We apply RANSAC + F matrix to further prune outliers.
5
+ Note: This script doesn't use descriptors.
6
+ '''
7
+ import argparse
8
+ import os
9
+ import time
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import torch
14
+ import imageio
15
+ from scipy.spatial import distance_matrix
16
+
17
+ from COTR.utils import utils, debug_utils
18
+ from COTR.models import build_model
19
+ from COTR.options.options import *
20
+ from COTR.options.options_utils import *
21
+ from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine
22
+
23
+ utils.fix_randomness(0)
24
+ torch.set_grad_enabled(False)
25
+
26
+
27
def main(opt):
    '''
    Run COTR guided matching between two images using precomputed keypoints.

    Each keypoint of one image is used as a COTR query, the predicted location
    is snapped to the nearest keypoint of the other image, and only pairs that
    agree in both directions are kept. The surviving pairs are filtered with
    RANSAC on the fundamental matrix and visualized.

    opt is the parsed option namespace; this function reads
    opt.load_weights_path and opt.faster_infer and hands opt to build_model.

    Raises RuntimeError when OpenCV returns no inlier mask.
    '''
    model = build_model(opt)
    model = model.cuda()
    # map_location='cpu' makes loading independent of the device the checkpoint
    # was saved from, consistent with the other demo scripts.
    weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
    utils.safe_load_weights(model, weights)
    model = model.eval()

    img_a = imageio.imread('./sample_data/imgs/21526113_4379776807.jpg')
    img_b = imageio.imread('./sample_data/imgs/21126421_4537535153.jpg')
    kp_a = np.load('./sample_data/21526113_4379776807.jpg.disk.kpts.npy')
    kp_b = np.load('./sample_data/21126421_4537535153.jpg.disk.kpts.npy')

    if opt.faster_infer:
        engine = FasterSparseEngine(model, 32, mode='tile')
    else:
        engine = SparseEngine(model, 32, mode='tile')
    t0 = time.time()
    corrs_a_b = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True)
    corrs_b_a = engine.cotr_corr_multiscale(img_b, img_a, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_b.shape[0], queries_a=kp_b, force=True)
    t1 = time.time()
    print(f'COTR spent {t1-t0} seconds.')

    # Nearest keypoint (by index) to every predicted location, in each direction.
    inds_a_b = np.argmin(distance_matrix(corrs_a_b[:, 2:], kp_b), axis=1)
    inds_b_a = np.argmin(distance_matrix(corrs_b_a[:, 2:], kp_a), axis=1)

    # Keep pair (i, j) only when a->b maps i to j and b->a maps j back to i.
    # One vectorized comparison replaces the former O(n*m) Python double loop
    # and yields the same pairs in the same order (ascending i).
    idx_a = np.arange(kp_a.shape[0])
    mutual = inds_b_a[inds_a_b] == idx_a
    final_matches = np.stack([idx_a[mutual], inds_a_b[mutual]], axis=1)
    final_corrs = np.concatenate([kp_a[final_matches[:, 0]], kp_b[final_matches[:, 1]]], axis=1)

    _, mask = cv2.findFundamentalMat(final_corrs[:, :2], final_corrs[:, 2:], cv2.FM_RANSAC, ransacReprojThreshold=5, confidence=0.999999)
    if mask is None:
        # NOTE(review): OpenCV appears to return no mask when estimation fails
        # (e.g. too few matches); surface that instead of a TypeError below.
        raise RuntimeError(f'fundamental matrix estimation failed with {final_matches.shape[0]} mutual matches')
    utils.visualize_corrs(img_a, img_b, final_corrs[np.where(mask[:, 0])])
65
+
66
+
67
if __name__ == "__main__":
    # sys.argv is used below; import it explicitly instead of relying on a
    # wildcard import from the options modules to leak the name.
    import sys

    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    parser.add_argument('--faster_infer', type=str2bool, default=False, help='use faster inference')

    opt = parser.parse_args()
    # Record the exact command line on the options object.
    opt.command = ' '.join(sys.argv)

    # The feed-forward width is looked up from the selected backbone layer.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    elif getattr(opt, 'load_weights_path', None) is None:
        # main() reads opt.load_weights_path unconditionally; fail here with a
        # usage message rather than with an AttributeError after model setup.
        parser.error('--load_weights is required: main() loads a checkpoint from opt.load_weights_path')
    print_opt(opt)
    main(opt)
third_party/COTR/demo_homography.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ COTR demo for homography estimation
3
+ '''
4
+ import argparse
5
+ import os
6
+ import time
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import imageio
12
+ import matplotlib.pyplot as plt
13
+
14
+ from COTR.utils import utils, debug_utils
15
+ from COTR.models import build_model
16
+ from COTR.options.options import *
17
+ from COTR.options.options_utils import *
18
+ from COTR.inference.inference_helper import triangulate_corr
19
+ from COTR.inference.sparse_engine import SparseEngine
20
+
21
+ utils.fix_randomness(0)
22
+ torch.set_grad_enabled(False)
23
+
24
+
25
def main(opt):
    '''
    Run the homography demo: track four annotated corners from the first
    frame into the second frame with COTR, then warp a replacement painting
    onto the quadrilateral they define and show the result.

    opt is the parsed option namespace; this function reads
    opt.load_weights_path and hands opt to build_model.
    '''
    net = build_model(opt).cuda()
    checkpoint = torch.load(opt.load_weights_path, map_location='cpu')
    utils.safe_load_weights(net, checkpoint['model_state_dict'])
    net = net.eval()

    img_a = imageio.imread('./sample_data/imgs/paint_1.JPG', pilmode='RGB')
    img_b = imageio.imread('./sample_data/imgs/paint_2.jpg', pilmode='RGB')
    rep_img = imageio.imread('./sample_data/imgs/Meisje_met_de_parel.jpg', pilmode='RGB')
    rep_h, rep_w = rep_img.shape[:2]
    rep_mask = np.ones((rep_h, rep_w))

    # Hand-annotated corners in the first frame, ordered
    # left-upper, right-upper, left-bottom, right-bottom.
    queries = np.array(
        [[932, 1025], [2469, 901], [908, 2927], [2436, 3080]],
        dtype=np.float32,
    )
    # Corners of the replacement image in the same order.
    rep_coord = np.array(
        [[0, 0], [rep_w, 0], [0, rep_h], [rep_w, rep_h]],
        dtype=np.float32,
    )

    engine = SparseEngine(net, 32, mode='stretching')
    corrs = engine.cotr_corr_multiscale(
        img_a,
        img_b,
        np.linspace(0.5, 0.0625, 4),
        1,
        queries_a=queries,
        force=True,
    )

    # Map the replacement image corners onto the locations predicted in img_b.
    target_size = (img_b.shape[1], img_b.shape[0])
    T = cv2.getPerspectiveTransform(rep_coord, corrs[:, 2:].astype(np.float32))
    vmask = cv2.warpPerspective(rep_mask, T, target_size) > 0
    warped = cv2.warpPerspective(rep_img, T, target_size)
    covered = vmask[..., None]
    # Warped painting where the mask is set, original frame elsewhere.
    out = warped * covered + img_b * (~covered)

    _, axarr = plt.subplots(1, 4)
    panels = (
        (rep_img, 'Virtual Paint'),
        (img_a, 'Annotated Frame'),
        (img_b, 'Target Frame'),
        (out, 'Overlay'),
    )
    for axis, (image, title) in zip(axarr, panels):
        axis.imshow(image)
        axis.title.set_text(title)
        axis.axis('off')
    plt.show()
65
+
66
+
67
if __name__ == "__main__":
    # sys.argv is used below; import it explicitly instead of relying on a
    # wildcard import from the options modules to leak the name.
    import sys

    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')

    opt = parser.parse_args()
    # Record the exact command line on the options object.
    opt.command = ' '.join(sys.argv)

    # The feed-forward width is looked up from the selected backbone layer.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    elif getattr(opt, 'load_weights_path', None) is None:
        # main() reads opt.load_weights_path unconditionally; fail here with a
        # usage message rather than with an AttributeError after model setup.
        parser.error('--load_weights is required: main() loads a checkpoint from opt.load_weights_path')
    print_opt(opt)
    main(opt)
third_party/COTR/demo_reconstruction.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ COTR two view reconstruction with known extrinsic/intrinsic demo
3
+ '''
4
+ import argparse
5
+ import os
6
+ import time
7
+
8
+ import numpy as np
9
+ import torch
10
+ import imageio
11
+ import open3d as o3d
12
+
13
+ from COTR.utils import utils, debug_utils
14
+ from COTR.models import build_model
15
+ from COTR.options.options import *
16
+ from COTR.options.options_utils import *
17
+ from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine
18
+ from COTR.projector import pcd_projector
19
+
20
+ utils.fix_randomness(0)
21
+ torch.set_grad_enabled(False)
22
+
23
+
24
def triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b):
    '''
    For every pair of rays, return the point on ray A that is closest to ray B.

    Args:
        center_a: (N, D) array of ray-A origins.
        dir_a: (N, D) array of ray-A directions (normalized internally).
        center_b: (N, D) array of ray-B origins.
        dir_b: (N, D) array of ray-B directions (normalized internally).

    Returns:
        (N, D) array; row i lies on ray A_i at the parameter that minimizes
        the distance to the line through ray B_i.

    When the two rays are (numerically) parallel the closest point is not
    unique and the closed-form solution is 0/0. In that case the projection of
    B's origin onto ray A is returned, so the output stays finite instead of
    becoming NaN.
    '''
    unit_a = dir_a / np.linalg.norm(dir_a, axis=1, keepdims=True)
    unit_b = dir_b / np.linalg.norm(dir_b, axis=1, keepdims=True)
    offset = center_b - center_a

    a_dot_a = np.sum(unit_a * unit_a, axis=1)
    b_dot_b = np.sum(unit_b * unit_b, axis=1)
    a_dot_b = np.sum(unit_a * unit_b, axis=1)
    a_dot_c = np.sum(unit_a * offset, axis=1)
    b_dot_c = np.sum(unit_b * offset, axis=1)

    numerator = a_dot_c * b_dot_b - a_dot_b * b_dot_c
    # For unit vectors this equals sin^2 of the angle between the rays.
    denominator = a_dot_a * b_dot_b - a_dot_b * a_dot_b

    # Treat rays as parallel when the denominator is at rounding-noise level.
    tol = 8 * np.finfo(denominator.dtype).eps
    parallel = np.abs(denominator) < tol
    safe_denominator = np.where(parallel, 1.0, denominator)
    depth = np.where(parallel, a_dot_c, numerator / safe_denominator)

    return center_a + unit_a * depth[..., None]
32
+
33
+
34
def main(opt):
    '''
    Run the two-view reconstruction demo: match the two sample images with
    COTR, back-project the matches with the stored cameras, triangulate a
    point per match and display the colored point cloud in Open3D.

    opt is the parsed option namespace; this function reads
    opt.load_weights_path, opt.faster_infer and opt.max_corrs and hands opt
    to build_model.
    '''
    net = build_model(opt).cuda()
    checkpoint = torch.load(opt.load_weights_path, map_location='cpu')
    utils.safe_load_weights(net, checkpoint['model_state_dict'])
    net = net.eval()

    img_a = imageio.imread('./sample_data/imgs/img_0.jpg', pilmode='RGB')
    img_b = imageio.imread('./sample_data/imgs/img_1.jpg', pilmode='RGB')

    engine_cls = FasterSparseEngine if opt.faster_infer else SparseEngine
    engine = engine_cls(net, 32, mode='tile')
    t0 = time.time()
    corrs = engine.cotr_corr_multiscale_with_cycle_consistency(
        img_a,
        img_b,
        np.linspace(0.5, 0.0625, 4),
        1,
        max_corrs=opt.max_corrs,
        queries_a=None,
    )
    t1 = time.time()
    print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.')

    # NOTE: allow_pickle=True unpickles the file contents; only use with
    # trusted camera files.
    camera_a = np.load('./sample_data/camera_0.npy', allow_pickle=True).item()
    camera_b = np.load('./sample_data/camera_1.npy', allow_pickle=True).item()
    center_a = camera_a['cam_center']
    center_b = camera_b['cam_center']

    num_corrs = corrs.shape[0]
    pix_a = corrs[:, :2]
    pix_b = corrs[:, 2:]
    # Back-project every match using a constant value of 2 as the second
    # argument; only the direction from the camera center is used afterwards.
    to_3d = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np
    rays_a = to_3d(pix_a, np.ones([num_corrs, 1]) * 2, camera_a['intrinsic'], motion=camera_a['c2w'])
    rays_b = to_3d(pix_b, np.ones([num_corrs, 1]) * 2, camera_b['intrinsic'], motion=camera_b['c2w'])
    dir_a = rays_a - center_a
    dir_b = rays_b - center_b
    # Repeat each camera center once per correspondence.
    center_a = np.array([center_a] * num_corrs)
    center_b = np.array([center_b] * num_corrs)
    points = triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b)

    # Color of a point: mean of the two pixels it was matched from. The
    # (x, y) columns are floored and flipped to (row, col) for indexing.
    index_a = tuple(np.floor(pix_a).astype(int)[:, ::-1].T)
    index_b = tuple(np.floor(pix_b).astype(int)[:, ::-1].T)
    colors = np.array((img_a[index_a] / 255 + img_b[index_b] / 255) / 2)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    o3d.visualization.draw_geometries([pcd])
71
+
72
+
73
if __name__ == "__main__":
    # sys.argv is used below; import it explicitly instead of relying on a
    # wildcard import from the options modules to leak the name.
    import sys

    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    parser.add_argument('--max_corrs', type=int, default=2048, help='number of correspondences')
    parser.add_argument('--faster_infer', type=str2bool, default=False, help='use faster inference')

    opt = parser.parse_args()
    # Record the exact command line on the options object.
    opt.command = ' '.join(sys.argv)

    # The feed-forward width is looked up from the selected backbone layer.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    elif getattr(opt, 'load_weights_path', None) is None:
        # main() reads opt.load_weights_path unconditionally; fail here with a
        # usage message rather than with an AttributeError after model setup.
        parser.error('--load_weights is required: main() loads a checkpoint from opt.load_weights_path')
    print_opt(opt)
    main(opt)
third_party/COTR/demo_single_pair.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ COTR demo for a single image pair
3
+ '''
4
+ import argparse
5
+ import os
6
+ import time
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import imageio
12
+ import matplotlib.pyplot as plt
13
+
14
+ from COTR.utils import utils, debug_utils
15
+ from COTR.models import build_model
16
+ from COTR.options.options import *
17
+ from COTR.options.options_utils import *
18
+ from COTR.inference.inference_helper import triangulate_corr
19
+ from COTR.inference.sparse_engine import SparseEngine
20
+
21
+ utils.fix_randomness(0)
22
+ torch.set_grad_enabled(False)
23
+
24
+
25
def main(opt):
    '''
    Run the single-pair demo: find cycle-consistent COTR correspondences
    between the two cathedral images, visualize them, then interpolate a
    dense map from them and show the second image warped onto the first.

    opt is the parsed option namespace; this function reads
    opt.load_weights_path and opt.max_corrs and hands opt to build_model.
    '''
    net = build_model(opt).cuda()
    checkpoint = torch.load(opt.load_weights_path, map_location='cpu')
    utils.safe_load_weights(net, checkpoint['model_state_dict'])
    net = net.eval()

    img_a = imageio.imread('./sample_data/imgs/cathedral_1.jpg', pilmode='RGB')
    img_b = imageio.imread('./sample_data/imgs/cathedral_2.jpg', pilmode='RGB')

    engine = SparseEngine(net, 32, mode='tile')
    t0 = time.time()
    corrs = engine.cotr_corr_multiscale_with_cycle_consistency(
        img_a,
        img_b,
        np.linspace(0.5, 0.0625, 4),
        1,
        max_corrs=opt.max_corrs,
        queries_a=None,
    )
    t1 = time.time()

    utils.visualize_corrs(img_a, img_b, corrs)
    print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.')

    # Channels 0 and 1 of the dense map are passed to cv2.remap as map1/map2.
    dense = triangulate_corr(corrs, img_a.shape, img_b.shape)
    map_x = dense[..., 0].astype(np.float32)
    map_y = dense[..., 1].astype(np.float32)
    warped = cv2.remap(
        img_b,
        map_x,
        map_y,
        interpolation=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_CONSTANT,
    )
    # 50/50 blend of the warped second image and the first image.
    blend = warped / 255 * 0.5 + img_a / 255 * 0.5
    plt.imshow(blend)
    plt.show()
46
+
47
+
48
if __name__ == "__main__":
    # sys.argv is used below; import it explicitly instead of relying on a
    # wildcard import from the options modules to leak the name.
    import sys

    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    parser.add_argument('--max_corrs', type=int, default=100, help='number of correspondences')

    opt = parser.parse_args()
    # Record the exact command line on the options object.
    opt.command = ' '.join(sys.argv)

    # The feed-forward width is looked up from the selected backbone layer.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    elif getattr(opt, 'load_weights_path', None) is None:
        # main() reads opt.load_weights_path unconditionally; fail here with a
        # usage message rather than with an AttributeError after model setup.
        parser.error('--load_weights is required: main() loads a checkpoint from opt.load_weights_path')
    print_opt(opt)
    main(opt)
third_party/COTR/demo_wbs.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ Manually passing scale to COTR, skip the scale difference estimation.
3
+ '''
4
+ import argparse
5
+ import os
6
+ import time
7
+
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ import imageio
12
+ from scipy.spatial import distance_matrix
13
+ import matplotlib.pyplot as plt
14
+
15
+ from COTR.utils import utils, debug_utils
16
+ from COTR.models import build_model
17
+ from COTR.options.options import *
18
+ from COTR.options.options_utils import *
19
+ from COTR.inference.sparse_engine import SparseEngine
20
+
21
+ utils.fix_randomness(0)
22
+ torch.set_grad_enabled(False)
23
+
24
+
25
def main(opt):
    '''
    Run COTR on the petrzin pair with manually supplied image areas and
    compare the predictions with the ground-truth correspondences.

    The first two columns of the ground-truth file are used as queries in the
    first image; the last two columns are the expected locations in the second
    image and are plotted next to the predictions, joined by red segments.

    opt is the parsed option namespace; this function reads
    opt.load_weights_path and hands opt to build_model.
    '''
    model = build_model(opt)
    model = model.cuda()
    # map_location='cpu' makes loading independent of the device the checkpoint
    # was saved from, consistent with the other demo scripts.
    weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
    utils.safe_load_weights(model, weights)
    model = model.eval()

    img_a = imageio.imread('./sample_data/imgs/petrzin_01.png')
    img_b = imageio.imread('./sample_data/imgs/petrzin_02.png')
    # Areas passed explicitly to the engine (see the module docstring: the
    # scale is supplied manually instead of being estimated).
    img_a_area = 1.0
    img_b_area = 1.0
    gt_corrs = np.loadtxt('./sample_data/petrzin_pts.txt')
    kp_a = gt_corrs[:, :2]
    kp_b = gt_corrs[:, 2:]

    engine = SparseEngine(model, 32, mode='tile')
    t0 = time.time()
    corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.75, 0.1, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True, areas=[img_a_area, img_b_area])
    t1 = time.time()
    print(f'COTR spent {t1-t0} seconds.')

    utils.visualize_corrs(img_a, img_b, corrs)
    plt.imshow(img_b)
    plt.scatter(kp_b[:, 0], kp_b[:, 1])
    plt.scatter(corrs[:, 2], corrs[:, 3])
    # One red segment per correspondence, from ground truth to prediction.
    plt.plot(np.stack([kp_b[:, 0], corrs[:, 2]], axis=1).T, np.stack([kp_b[:, 1], corrs[:, 3]], axis=1).T, color=[1, 0, 0])
    plt.show()
52
+
53
+
54
if __name__ == "__main__":
    # sys.argv is used below; import it explicitly instead of relying on a
    # wildcard import from the options modules to leak the name.
    import sys

    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')

    opt = parser.parse_args()
    # Record the exact command line on the options object.
    opt.command = ' '.join(sys.argv)

    # The feed-forward width is looked up from the selected backbone layer.
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    opt.dim_feedforward = layer_2_channels[opt.layer]
    if opt.load_weights:
        opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    elif getattr(opt, 'load_weights_path', None) is None:
        # main() reads opt.load_weights_path unconditionally; fail here with a
        # usage message rather than with an AttributeError after model setup.
        parser.error('--load_weights is required: main() loads a checkpoint from opt.load_weights_path')
    print_opt(opt)
    main(opt)
third_party/COTR/environment.yml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: cotr_env
2
+ channels:
3
+ - pytorch
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=main
7
+ - backcall=0.2.0=pyhd3eb1b0_0
8
+ - blas=1.0=mkl
9
+ - bzip2=1.0.8=h7b6447c_0
10
+ - ca-certificates=2021.4.13=h06a4308_1
11
+ - cairo=1.16.0=hf32fb01_1
12
+ - certifi=2020.12.5=py37h06a4308_0
13
+ - cudatoolkit=10.2.89=hfd86e86_1
14
+ - cycler=0.10.0=py37_0
15
+ - dbus=1.13.18=hb2f20db_0
16
+ - decorator=5.0.6=pyhd3eb1b0_0
17
+ - expat=2.3.0=h2531618_2
18
+ - ffmpeg=4.0=hcdf2ecd_0
19
+ - fontconfig=2.13.1=h6c09931_0
20
+ - freeglut=3.0.0=hf484d3e_5
21
+ - freetype=2.10.4=h5ab3b9f_0
22
+ - glib=2.68.1=h36276a3_0
23
+ - graphite2=1.3.14=h23475e2_0
24
+ - gst-plugins-base=1.14.0=h8213a91_2
25
+ - gstreamer=1.14.0=h28cd5cc_2
26
+ - harfbuzz=1.8.8=hffaf4a1_0
27
+ - hdf5=1.10.2=hba1933b_1
28
+ - icu=58.2=he6710b0_3
29
+ - imageio=2.9.0=pyhd3eb1b0_0
30
+ - intel-openmp=2021.2.0=h06a4308_610
31
+ - ipython=7.22.0=py37hb070fc8_0
32
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
33
+ - jasper=2.0.14=h07fcdf6_1
34
+ - jedi=0.17.0=py37_0
35
+ - jpeg=9b=h024ee3a_2
36
+ - kiwisolver=1.3.1=py37h2531618_0
37
+ - lcms2=2.12=h3be6417_0
38
+ - ld_impl_linux-64=2.33.1=h53a641e_7
39
+ - libffi=3.3=he6710b0_2
40
+ - libgcc-ng=9.1.0=hdf63c60_0
41
+ - libgfortran-ng=7.3.0=hdf63c60_0
42
+ - libglu=9.0.0=hf484d3e_1
43
+ - libopencv=3.4.2=hb342d67_1
44
+ - libopus=1.3.1=h7b6447c_0
45
+ - libpng=1.6.37=hbc83047_0
46
+ - libstdcxx-ng=9.1.0=hdf63c60_0
47
+ - libtiff=4.1.0=h2733197_1
48
+ - libuuid=1.0.3=h1bed415_2
49
+ - libuv=1.40.0=h7b6447c_0
50
+ - libvpx=1.7.0=h439df22_0
51
+ - libxcb=1.14=h7b6447c_0
52
+ - libxml2=2.9.10=hb55368b_3
53
+ - lz4-c=1.9.3=h2531618_0
54
+ - matplotlib=3.3.4=py37h06a4308_0
55
+ - matplotlib-base=3.3.4=py37h62a2d02_0
56
+ - mkl=2020.2=256
57
+ - mkl-service=2.3.0=py37he8ac12f_0
58
+ - mkl_fft=1.3.0=py37h54f3939_0
59
+ - mkl_random=1.1.1=py37h0573a6f_0
60
+ - ncurses=6.2=he6710b0_1
61
+ - ninja=1.10.2=hff7bd54_1
62
+ - numpy=1.19.2=py37h54aff64_0
63
+ - numpy-base=1.19.2=py37hfa32c7d_0
64
+ - olefile=0.46=py37_0
65
+ - opencv=3.4.2=py37h6fd60c2_1
66
+ - openssl=1.1.1k=h27cfd23_0
67
+ - parso=0.8.2=pyhd3eb1b0_0
68
+ - pcre=8.44=he6710b0_0
69
+ - pexpect=4.8.0=pyhd3eb1b0_3
70
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
71
+ - pillow=8.2.0=py37he98fc37_0
72
+ - pip=21.0.1=py37h06a4308_0
73
+ - pixman=0.40.0=h7b6447c_0
74
+ - prompt-toolkit=3.0.17=pyh06a4308_0
75
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
76
+ - py-opencv=3.4.2=py37hb342d67_1
77
+ - pygments=2.8.1=pyhd3eb1b0_0
78
+ - pyparsing=2.4.7=pyhd3eb1b0_0
79
+ - pyqt=5.9.2=py37h05f1152_2
80
+ - python=3.7.10=hdb3f193_0
81
+ - python-dateutil=2.8.1=pyhd3eb1b0_0
82
+ - pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0
83
+ - qt=5.9.7=h5867ecd_1
84
+ - readline=8.1=h27cfd23_0
85
+ - scipy=1.2.1=py37h7c811a0_0
86
+ - setuptools=52.0.0=py37h06a4308_0
87
+ - sip=4.19.8=py37hf484d3e_0
88
+ - six=1.15.0=py37h06a4308_0
89
+ - sqlite=3.35.4=hdfb4753_0
90
+ - tk=8.6.10=hbc83047_0
91
+ - torchaudio=0.7.2=py37
92
+ - torchvision=0.8.2=py37_cu102
93
+ - tornado=6.1=py37h27cfd23_0
94
+ - tqdm=4.59.0=pyhd3eb1b0_1
95
+ - traitlets=5.0.5=pyhd3eb1b0_0
96
+ - typing_extensions=3.7.4.3=pyha847dfd_0
97
+ - vispy=0.5.3=py37hee6b756_0
98
+ - wcwidth=0.2.5=py_0
99
+ - wheel=0.36.2=pyhd3eb1b0_0
100
+ - xz=5.2.5=h7b6447c_0
101
+ - zlib=1.2.11=h7b6447c_3
102
+ - zstd=1.4.9=haebb681_0
103
+ - pip:
104
+ - tables==3.6.1
third_party/COTR/out/.DS_Store ADDED
Binary file (6.15 kB). View file
 
third_party/COTR/out/.placeholder ADDED
File without changes
third_party/COTR/out/default/checkpoint.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abfa1183408dc566535146b41508ed02084d5f5d1a150f5c188ee479463d6d5c
3
+ size 219363688