Spaces:
Running
Running
Realcat
commited on
Commit
·
10dcc2e
1
Parent(s):
d21720c
add: COTR(https://github.com/ubc-vision/COTR)
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +1 -0
- common/app_class.py +3 -1
- common/config.yaml +10 -0
- common/utils.py +33 -5
- env-docker.txt +2 -2
- hloc/match_dense.py +38 -0
- hloc/matchers/cotr.py +77 -0
- third_party/COTR/.gitignore +1 -0
- third_party/COTR/COTR/cameras/camera_pose.py +164 -0
- third_party/COTR/COTR/cameras/capture.py +432 -0
- third_party/COTR/COTR/cameras/pinhole_camera.py +73 -0
- third_party/COTR/COTR/datasets/colmap_helper.py +312 -0
- third_party/COTR/COTR/datasets/cotr_dataset.py +243 -0
- third_party/COTR/COTR/datasets/megadepth_dataset.py +140 -0
- third_party/COTR/COTR/global_configs/__init__.py +10 -0
- third_party/COTR/COTR/global_configs/commons.json +1 -0
- third_party/COTR/COTR/global_configs/dataset_config.json +41 -0
- third_party/COTR/COTR/inference/inference_helper.py +311 -0
- third_party/COTR/COTR/inference/refinement_task.py +191 -0
- third_party/COTR/COTR/inference/sparse_engine.py +427 -0
- third_party/COTR/COTR/models/__init__.py +10 -0
- third_party/COTR/COTR/models/backbone.py +135 -0
- third_party/COTR/COTR/models/cotr_model.py +51 -0
- third_party/COTR/COTR/models/misc.py +112 -0
- third_party/COTR/COTR/models/position_encoding.py +83 -0
- third_party/COTR/COTR/models/transformer.py +228 -0
- third_party/COTR/COTR/options/options.py +52 -0
- third_party/COTR/COTR/options/options_utils.py +108 -0
- third_party/COTR/COTR/projector/pcd_projector.py +210 -0
- third_party/COTR/COTR/sfm_scenes/knn_search.py +56 -0
- third_party/COTR/COTR/sfm_scenes/sfm_scenes.py +87 -0
- third_party/COTR/COTR/trainers/base_trainer.py +111 -0
- third_party/COTR/COTR/trainers/cotr_trainer.py +200 -0
- third_party/COTR/COTR/trainers/tensorboard_helper.py +97 -0
- third_party/COTR/COTR/transformations/transform_basics.py +114 -0
- third_party/COTR/COTR/transformations/transformations.py +1951 -0
- third_party/COTR/COTR/utils/constants.py +3 -0
- third_party/COTR/COTR/utils/debug_utils.py +15 -0
- third_party/COTR/COTR/utils/utils.py +271 -0
- third_party/COTR/LICENSE +201 -0
- third_party/COTR/demo_face.py +69 -0
- third_party/COTR/demo_guided_matching.py +85 -0
- third_party/COTR/demo_homography.py +84 -0
- third_party/COTR/demo_reconstruction.py +92 -0
- third_party/COTR/demo_single_pair.py +66 -0
- third_party/COTR/demo_wbs.py +71 -0
- third_party/COTR/environment.yml +104 -0
- third_party/COTR/out/.DS_Store +0 -0
- third_party/COTR/out/.placeholder +0 -0
- third_party/COTR/out/default/checkpoint.pth.tar +3 -0
README.md
CHANGED
@@ -56,6 +56,7 @@ The tool currently supports various popular image matching algorithms, namely:
|
|
56 |
- [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022
|
57 |
- [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022
|
58 |
- [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022
|
|
|
59 |
- [x] [ALIKE](https://github.com/Shiaoming/ALIKE), TMM 2022
|
60 |
- [x] [RoRD](https://github.com/UditSinghParihar/RoRD), IROS 2021
|
61 |
- [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021
|
|
|
56 |
- [x] [LANet](https://github.com/wangch-g/lanet), ACCV 2022
|
57 |
- [ ] [LISRD](https://github.com/rpautrat/LISRD), ECCV 2022
|
58 |
- [ ] [REKD](https://github.com/bluedream1121/REKD), CVPR 2022
|
59 |
+
- [x] [CoTR](https://github.com/ubc-vision/COTR), ICCV 2021
|
60 |
- [x] [ALIKE](https://github.com/Shiaoming/ALIKE), TMM 2022
|
61 |
- [x] [RoRD](https://github.com/UditSinghParihar/RoRD), IROS 2021
|
62 |
- [x] [SGMNet](https://github.com/vdvchen/SGMNet), ICCV 2021
|
common/app_class.py
CHANGED
@@ -300,6 +300,7 @@ class ImageMatchingApp:
|
|
300 |
fn=run_ransac,
|
301 |
inputs=[
|
302 |
state_cache,
|
|
|
303 |
ransac_method,
|
304 |
ransac_reproj_threshold,
|
305 |
ransac_confidence,
|
@@ -308,6 +309,7 @@ class ImageMatchingApp:
|
|
308 |
outputs=[
|
309 |
output_matches_ransac,
|
310 |
matches_result_info,
|
|
|
311 |
],
|
312 |
)
|
313 |
|
@@ -457,7 +459,7 @@ class ImageMatchingApp:
|
|
457 |
return gr.Markdown(markdown_table)
|
458 |
elif style == "tab":
|
459 |
for k, v in cfg.items():
|
460 |
-
if not v["info"]
|
461 |
continue
|
462 |
data.append(
|
463 |
[
|
|
|
300 |
fn=run_ransac,
|
301 |
inputs=[
|
302 |
state_cache,
|
303 |
+
choice_geometry_type,
|
304 |
ransac_method,
|
305 |
ransac_reproj_threshold,
|
306 |
ransac_confidence,
|
|
|
309 |
outputs=[
|
310 |
output_matches_ransac,
|
311 |
matches_result_info,
|
312 |
+
output_wrapped,
|
313 |
],
|
314 |
)
|
315 |
|
|
|
459 |
return gr.Markdown(markdown_table)
|
460 |
elif style == "tab":
|
461 |
for k, v in cfg.items():
|
462 |
+
if not v["info"].get("display", True):
|
463 |
continue
|
464 |
data.append(
|
465 |
[
|
common/config.yaml
CHANGED
@@ -46,6 +46,16 @@ matcher_zoo:
|
|
46 |
paper: https://arxiv.org/pdf/2104.00680
|
47 |
project: https://zju3dv.github.io/loftr
|
48 |
display: true
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
topicfm:
|
50 |
matcher: topicfm
|
51 |
dense: true
|
|
|
46 |
paper: https://arxiv.org/pdf/2104.00680
|
47 |
project: https://zju3dv.github.io/loftr
|
48 |
display: true
|
49 |
+
cotr:
|
50 |
+
matcher: cotr
|
51 |
+
dense: true
|
52 |
+
info:
|
53 |
+
name: CoTR #dispaly name
|
54 |
+
source: "ICCV 2021"
|
55 |
+
github: https://github.com/ubc-vision/COTR
|
56 |
+
paper: https://arxiv.org/abs/2103.14167
|
57 |
+
project: null
|
58 |
+
display: true
|
59 |
topicfm:
|
60 |
matcher: topicfm
|
61 |
dense: true
|
common/utils.py
CHANGED
@@ -443,6 +443,7 @@ def generate_warp_images(
|
|
443 |
|
444 |
def run_ransac(
|
445 |
state_cache: Dict[str, Any],
|
|
|
446 |
ransac_method: str = DEFAULT_RANSAC_METHOD,
|
447 |
ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
|
448 |
ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
|
@@ -493,11 +494,32 @@ def run_ransac(
|
|
493 |
)
|
494 |
logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
|
495 |
t1 = time.time()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
496 |
num_matches_raw = state_cache["num_matches_raw"]
|
497 |
-
return
|
498 |
-
|
499 |
-
|
500 |
-
|
|
|
|
|
|
|
|
|
501 |
|
502 |
|
503 |
def run_matching(
|
@@ -666,7 +688,13 @@ def run_matching(
|
|
666 |
|
667 |
t1 = time.time()
|
668 |
# plot wrapped images
|
669 |
-
geom_info = compute_geometry(
|
|
|
|
|
|
|
|
|
|
|
|
|
670 |
output_wrapped, _ = generate_warp_images(
|
671 |
pred["image0_orig"],
|
672 |
pred["image1_orig"],
|
|
|
443 |
|
444 |
def run_ransac(
|
445 |
state_cache: Dict[str, Any],
|
446 |
+
choice_geometry_type: str,
|
447 |
ransac_method: str = DEFAULT_RANSAC_METHOD,
|
448 |
ransac_reproj_threshold: int = DEFAULT_RANSAC_REPROJ_THRESHOLD,
|
449 |
ransac_confidence: float = DEFAULT_RANSAC_CONFIDENCE,
|
|
|
494 |
)
|
495 |
logger.info(f"Display matches done using: {time.time()-t1:.3f}s")
|
496 |
t1 = time.time()
|
497 |
+
|
498 |
+
# compute warp images
|
499 |
+
geom_info = compute_geometry(
|
500 |
+
state_cache,
|
501 |
+
ransac_method=ransac_method,
|
502 |
+
ransac_reproj_threshold=ransac_reproj_threshold,
|
503 |
+
ransac_confidence=ransac_confidence,
|
504 |
+
ransac_max_iter=ransac_max_iter,
|
505 |
+
)
|
506 |
+
output_wrapped, _ = generate_warp_images(
|
507 |
+
state_cache["image0_orig"],
|
508 |
+
state_cache["image1_orig"],
|
509 |
+
{"geom_info": geom_info},
|
510 |
+
choice_geometry_type,
|
511 |
+
)
|
512 |
+
plt.close("all")
|
513 |
+
|
514 |
num_matches_raw = state_cache["num_matches_raw"]
|
515 |
+
return (
|
516 |
+
output_matches_ransac,
|
517 |
+
{
|
518 |
+
"num_matches_raw": num_matches_raw,
|
519 |
+
"num_matches_ransac": num_matches_ransac,
|
520 |
+
},
|
521 |
+
output_wrapped,
|
522 |
+
)
|
523 |
|
524 |
|
525 |
def run_matching(
|
|
|
688 |
|
689 |
t1 = time.time()
|
690 |
# plot wrapped images
|
691 |
+
geom_info = compute_geometry(
|
692 |
+
pred,
|
693 |
+
ransac_method=ransac_method,
|
694 |
+
ransac_reproj_threshold=ransac_reproj_threshold,
|
695 |
+
ransac_confidence=ransac_confidence,
|
696 |
+
ransac_max_iter=ransac_max_iter,
|
697 |
+
)
|
698 |
output_wrapped, _ = generate_warp_images(
|
699 |
pred["image0_orig"],
|
700 |
pred["image1_orig"],
|
env-docker.txt
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
e2cnn==0.2.3
|
2 |
einops==0.6.1
|
3 |
gdown==4.7.1
|
4 |
-
gradio==
|
5 |
-
gradio_client==0.
|
6 |
h5py==3.9.0
|
7 |
imageio==2.31.1
|
8 |
Jinja2==3.1.2
|
|
|
1 |
e2cnn==0.2.3
|
2 |
einops==0.6.1
|
3 |
gdown==4.7.1
|
4 |
+
gradio==4.28.3
|
5 |
+
gradio_client==0.16.0
|
6 |
h5py==3.9.0
|
7 |
imageio==2.31.1
|
8 |
Jinja2==3.1.2
|
hloc/match_dense.py
CHANGED
@@ -28,6 +28,44 @@ confs = {
|
|
28 |
"max_error": 1, # max error for assigned keypoints (in px)
|
29 |
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
|
30 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
# Semi-scalable loftr which limits detected keypoints
|
32 |
"loftr_aachen": {
|
33 |
"output": "matches-loftr_aachen",
|
|
|
28 |
"max_error": 1, # max error for assigned keypoints (in px)
|
29 |
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
|
30 |
},
|
31 |
+
# "loftr_quadtree": {
|
32 |
+
# "output": "matches-loftr-quadtree",
|
33 |
+
# "model": {
|
34 |
+
# "name": "quadtree",
|
35 |
+
# "weights": "outdoor",
|
36 |
+
# "max_keypoints": 2000,
|
37 |
+
# "match_threshold": 0.2,
|
38 |
+
# },
|
39 |
+
# "preprocessing": {
|
40 |
+
# "grayscale": True,
|
41 |
+
# "resize_max": 1024,
|
42 |
+
# "dfactor": 8,
|
43 |
+
# "width": 640,
|
44 |
+
# "height": 480,
|
45 |
+
# "force_resize": True,
|
46 |
+
# },
|
47 |
+
# "max_error": 1, # max error for assigned keypoints (in px)
|
48 |
+
# "cell_size": 1, # size of quantization patch (max 1 kp/patch)
|
49 |
+
# },
|
50 |
+
"cotr": {
|
51 |
+
"output": "matches-cotr",
|
52 |
+
"model": {
|
53 |
+
"name": "cotr",
|
54 |
+
"weights": "out/default",
|
55 |
+
"max_keypoints": 2000,
|
56 |
+
"match_threshold": 0.2,
|
57 |
+
},
|
58 |
+
"preprocessing": {
|
59 |
+
"grayscale": False,
|
60 |
+
"resize_max": 1024,
|
61 |
+
"dfactor": 8,
|
62 |
+
"width": 640,
|
63 |
+
"height": 480,
|
64 |
+
"force_resize": True,
|
65 |
+
},
|
66 |
+
"max_error": 1, # max error for assigned keypoints (in px)
|
67 |
+
"cell_size": 1, # size of quantization patch (max 1 kp/patch)
|
68 |
+
},
|
69 |
# Semi-scalable loftr which limits detected keypoints
|
70 |
"loftr_aachen": {
|
71 |
"output": "matches-loftr_aachen",
|
hloc/matchers/cotr.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
import torch
|
4 |
+
import warnings
|
5 |
+
import numpy as np
|
6 |
+
from pathlib import Path
|
7 |
+
from torchvision.transforms import ToPILImage
|
8 |
+
from ..utils.base_model import BaseModel
|
9 |
+
|
10 |
+
sys.path.append(str(Path(__file__).parent / "../../third_party/COTR"))
|
11 |
+
from COTR.utils import utils as utils_cotr
|
12 |
+
from COTR.models import build_model
|
13 |
+
from COTR.options.options import *
|
14 |
+
from COTR.options.options_utils import *
|
15 |
+
from COTR.inference.inference_helper import triangulate_corr
|
16 |
+
from COTR.inference.sparse_engine import SparseEngine
|
17 |
+
|
18 |
+
utils_cotr.fix_randomness(0)
|
19 |
+
torch.set_grad_enabled(False)
|
20 |
+
|
21 |
+
cotr_path = Path(__file__).parent / "../../third_party/COTR"
|
22 |
+
|
23 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
24 |
+
|
25 |
+
|
26 |
+
class COTR(BaseModel):
|
27 |
+
default_conf = {
|
28 |
+
"weights": "out/default",
|
29 |
+
"match_threshold": 0.2,
|
30 |
+
"max_keypoints": -1,
|
31 |
+
}
|
32 |
+
required_inputs = ["image0", "image1"]
|
33 |
+
|
34 |
+
def _init(self, conf):
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
set_COTR_arguments(parser)
|
37 |
+
opt = parser.parse_args()
|
38 |
+
opt.command = " ".join(sys.argv)
|
39 |
+
opt.load_weights_path = str(
|
40 |
+
cotr_path / conf["weights"] / "checkpoint.pth.tar"
|
41 |
+
)
|
42 |
+
|
43 |
+
layer_2_channels = {
|
44 |
+
"layer1": 256,
|
45 |
+
"layer2": 512,
|
46 |
+
"layer3": 1024,
|
47 |
+
"layer4": 2048,
|
48 |
+
}
|
49 |
+
opt.dim_feedforward = layer_2_channels[opt.layer]
|
50 |
+
|
51 |
+
model = build_model(opt)
|
52 |
+
model = model.to(device)
|
53 |
+
weights = torch.load(opt.load_weights_path, map_location="cpu")[
|
54 |
+
"model_state_dict"
|
55 |
+
]
|
56 |
+
utils_cotr.safe_load_weights(model, weights)
|
57 |
+
self.net = model.eval()
|
58 |
+
self.to_pil_func = ToPILImage(mode="RGB")
|
59 |
+
|
60 |
+
def _forward(self, data):
|
61 |
+
img_a = np.array(self.to_pil_func(data["image0"][0].cpu()))
|
62 |
+
img_b = np.array(self.to_pil_func(data["image1"][0].cpu()))
|
63 |
+
corrs = SparseEngine(
|
64 |
+
self.net, 32, mode="tile"
|
65 |
+
).cotr_corr_multiscale_with_cycle_consistency(
|
66 |
+
img_a,
|
67 |
+
img_b,
|
68 |
+
np.linspace(0.5, 0.0625, 4),
|
69 |
+
1,
|
70 |
+
max_corrs=self.conf["max_keypoints"],
|
71 |
+
queries_a=None,
|
72 |
+
)
|
73 |
+
pred = {
|
74 |
+
"keypoints0": torch.from_numpy(corrs[:,:2]),
|
75 |
+
"keypoints1": torch.from_numpy(corrs[:,2:]),
|
76 |
+
}
|
77 |
+
return pred
|
third_party/COTR/.gitignore
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
*.pyc
|
third_party/COTR/COTR/cameras/camera_pose.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Extrinsic camera pose
|
3 |
+
'''
|
4 |
+
import math
|
5 |
+
import copy
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from COTR.transformations import transformations
|
10 |
+
from COTR.transformations.transform_basics import Translation, Rotation, UnstableRotation
|
11 |
+
|
12 |
+
|
13 |
+
class CameraPose():
|
14 |
+
def __init__(self, t: Translation, r: Rotation):
|
15 |
+
'''
|
16 |
+
WARN: World 2 cam
|
17 |
+
Translation and rotation are world to camera
|
18 |
+
translation_vector is not the coordinate of the camera in world space.
|
19 |
+
'''
|
20 |
+
assert isinstance(t, Translation)
|
21 |
+
assert isinstance(r, Rotation) or isinstance(r, UnstableRotation)
|
22 |
+
self.t = t
|
23 |
+
self.r = r
|
24 |
+
|
25 |
+
def __str__(self):
|
26 |
+
string = f'center in world: {self.camera_center_in_world}, translation(w2c): {self.t}, rotation(w2c): {self.r}'
|
27 |
+
return string
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def from_world_to_camera(cls, world_to_camera, unstable=False):
|
31 |
+
assert isinstance(world_to_camera, np.ndarray)
|
32 |
+
assert world_to_camera.shape == (4, 4)
|
33 |
+
vec = transformations.translation_from_matrix(world_to_camera).astype(np.float32)
|
34 |
+
t = Translation(vec)
|
35 |
+
if unstable:
|
36 |
+
r = UnstableRotation(world_to_camera)
|
37 |
+
else:
|
38 |
+
quat = transformations.quaternion_from_matrix(world_to_camera).astype(np.float32)
|
39 |
+
r = Rotation(quat)
|
40 |
+
return cls(t, r)
|
41 |
+
|
42 |
+
@classmethod
|
43 |
+
def from_camera_to_world(cls, camera_to_world, unstable=False):
|
44 |
+
assert isinstance(camera_to_world, np.ndarray)
|
45 |
+
assert camera_to_world.shape == (4, 4)
|
46 |
+
world_to_camera = np.linalg.inv(camera_to_world)
|
47 |
+
world_to_camera /= world_to_camera[3, 3]
|
48 |
+
return cls.from_world_to_camera(world_to_camera, unstable)
|
49 |
+
|
50 |
+
@classmethod
|
51 |
+
def from_pose_vector(cls, pose_vector):
|
52 |
+
t = Translation(pose_vector[:3])
|
53 |
+
r = Rotation(pose_vector[3:])
|
54 |
+
return cls(t, r)
|
55 |
+
|
56 |
+
@property
|
57 |
+
def translation_vector(self):
|
58 |
+
return self.t.translation_vector
|
59 |
+
|
60 |
+
@property
|
61 |
+
def translation_matrix(self):
|
62 |
+
return self.t.translation_matrix
|
63 |
+
|
64 |
+
@property
|
65 |
+
def quaternion(self):
|
66 |
+
'''
|
67 |
+
quaternion format (w, x, y, z)
|
68 |
+
'''
|
69 |
+
return self.r.quaternion
|
70 |
+
|
71 |
+
@property
|
72 |
+
def rotation_matrix(self):
|
73 |
+
return self.r.rotation_matrix
|
74 |
+
|
75 |
+
@property
|
76 |
+
def pose_vector(self):
|
77 |
+
'''
|
78 |
+
Pose vector is a concat of translation vector and quaternion vector
|
79 |
+
(X, Y, Z, w, x, y, z)
|
80 |
+
w2c
|
81 |
+
'''
|
82 |
+
return np.concatenate([self.translation_vector, self.quaternion])
|
83 |
+
|
84 |
+
@property
|
85 |
+
def inv_pose_vector(self):
|
86 |
+
inv_quat = transformations.quaternion_inverse(self.quaternion)
|
87 |
+
return np.concatenate([self.camera_center_in_world, inv_quat])
|
88 |
+
|
89 |
+
@property
|
90 |
+
def pose_vector_6_dof(self):
|
91 |
+
'''
|
92 |
+
Here we assuming the quaternion is normalized and we remove the W component
|
93 |
+
(X, Y, Z, x, y, z)
|
94 |
+
'''
|
95 |
+
return np.concatenate([self.translation_vector, self.quaternion[1:]])
|
96 |
+
|
97 |
+
@property
|
98 |
+
def world_to_camera(self):
|
99 |
+
M = np.matmul(self.translation_matrix, self.rotation_matrix)
|
100 |
+
M /= M[3, 3]
|
101 |
+
return M
|
102 |
+
|
103 |
+
@property
|
104 |
+
def world_to_camera_3x4(self):
|
105 |
+
M = self.world_to_camera
|
106 |
+
M = M[0:3, 0:4]
|
107 |
+
return M
|
108 |
+
|
109 |
+
@property
|
110 |
+
def extrinsic_mat(self):
|
111 |
+
return self.world_to_camera_3x4
|
112 |
+
|
113 |
+
@property
|
114 |
+
def camera_to_world(self):
|
115 |
+
M = np.linalg.inv(self.world_to_camera)
|
116 |
+
M /= M[3, 3]
|
117 |
+
return M
|
118 |
+
|
119 |
+
@property
|
120 |
+
def camera_to_world_3x4(self):
|
121 |
+
M = self.camera_to_world
|
122 |
+
M = M[0:3, 0:4]
|
123 |
+
return M
|
124 |
+
|
125 |
+
@property
|
126 |
+
def camera_center_in_world(self):
|
127 |
+
return self.camera_to_world[:3, 3]
|
128 |
+
|
129 |
+
@property
|
130 |
+
def forward(self):
|
131 |
+
return self.camera_to_world[:3, 2]
|
132 |
+
|
133 |
+
@property
|
134 |
+
def up(self):
|
135 |
+
return self.camera_to_world[:3, 1]
|
136 |
+
|
137 |
+
@property
|
138 |
+
def right(self):
|
139 |
+
return self.camera_to_world[:3, 0]
|
140 |
+
|
141 |
+
@property
|
142 |
+
def essential_matrix(self):
|
143 |
+
E = np.cross(self.rotation_matrix[:3, :3], self.camera_center_in_world)
|
144 |
+
return E / np.linalg.norm(E)
|
145 |
+
|
146 |
+
|
147 |
+
def inverse_camera_pose(cam_pose: CameraPose):
|
148 |
+
return CameraPose.from_world_to_camera(np.linalg.inv(cam_pose.world_to_camera))
|
149 |
+
|
150 |
+
|
151 |
+
def rotate_camera_pose(cam_pose, rot):
|
152 |
+
if rot == 0:
|
153 |
+
return copy.deepcopy(cam_pose)
|
154 |
+
else:
|
155 |
+
rot = rot / 180 * np.pi
|
156 |
+
sin_rot = np.sin(rot)
|
157 |
+
cos_rot = np.cos(rot)
|
158 |
+
|
159 |
+
rot_mat = np.stack([np.stack([cos_rot, -sin_rot, 0, 0], axis=-1),
|
160 |
+
np.stack([sin_rot, cos_rot, 0, 0], axis=-1),
|
161 |
+
np.stack([0, 0, 1, 0], axis=-1),
|
162 |
+
np.stack([0, 0, 0, 1], axis=-1)], axis=1)
|
163 |
+
new_world2cam = np.matmul(rot_mat, cam_pose.world_to_camera)
|
164 |
+
return CameraPose.from_world_to_camera(new_world2cam)
|
third_party/COTR/COTR/cameras/capture.py
ADDED
@@ -0,0 +1,432 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Capture from a pinhole camera
|
3 |
+
Separate the captured content and the camera...
|
4 |
+
'''
|
5 |
+
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import abc
|
9 |
+
import copy
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import torch
|
13 |
+
import numpy as np
|
14 |
+
import imageio
|
15 |
+
import PIL
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
from COTR.cameras.camera_pose import CameraPose, rotate_camera_pose
|
19 |
+
from COTR.cameras.pinhole_camera import PinholeCamera, rotate_pinhole_camera, crop_pinhole_camera
|
20 |
+
from COTR.utils import debug_utils, utils, constants
|
21 |
+
from COTR.utils.utils import Point2D
|
22 |
+
from COTR.projector import pcd_projector
|
23 |
+
from COTR.utils.constants import MAX_SIZE
|
24 |
+
from COTR.utils.utils import CropCamConfig
|
25 |
+
|
26 |
+
|
27 |
+
def crop_center_max_xy(p2d, shape):
|
28 |
+
h, w = shape
|
29 |
+
crop_x = min(h, w)
|
30 |
+
crop_y = crop_x
|
31 |
+
start_x = w // 2 - crop_x // 2
|
32 |
+
start_y = h // 2 - crop_y // 2
|
33 |
+
mask = (p2d.xy[:, 0] > start_x) & (p2d.xy[:, 0] < start_x + crop_x) & (p2d.xy[:, 1] > start_y) & (p2d.xy[:, 1] < start_y + crop_y)
|
34 |
+
out_xy = (p2d.xy - [start_x, start_y])[mask]
|
35 |
+
out = Point2D(p2d.id_3d[mask], out_xy)
|
36 |
+
return out
|
37 |
+
|
38 |
+
|
39 |
+
def crop_center_max(img):
|
40 |
+
if isinstance(img, torch.Tensor):
|
41 |
+
return crop_center_max_torch(img)
|
42 |
+
elif isinstance(img, np.ndarray):
|
43 |
+
return crop_center_max_np(img)
|
44 |
+
else:
|
45 |
+
raise ValueError
|
46 |
+
|
47 |
+
|
48 |
+
def crop_center_max_torch(img):
|
49 |
+
if len(img.shape) == 2:
|
50 |
+
h, w = img.shape
|
51 |
+
elif len(img.shape) == 3:
|
52 |
+
c, h, w = img.shape
|
53 |
+
elif len(img.shape) == 4:
|
54 |
+
b, c, h, w = img.shape
|
55 |
+
else:
|
56 |
+
raise ValueError
|
57 |
+
crop_x = min(h, w)
|
58 |
+
crop_y = crop_x
|
59 |
+
start_x = w // 2 - crop_x // 2
|
60 |
+
start_y = h // 2 - crop_y // 2
|
61 |
+
if len(img.shape) == 2:
|
62 |
+
return img[start_y:start_y + crop_y, start_x:start_x + crop_x]
|
63 |
+
elif len(img.shape) in [3, 4]:
|
64 |
+
return img[..., start_y:start_y + crop_y, start_x:start_x + crop_x]
|
65 |
+
|
66 |
+
|
67 |
+
def crop_center_max_np(img, return_starts=False):
|
68 |
+
if len(img.shape) == 2:
|
69 |
+
h, w = img.shape
|
70 |
+
elif len(img.shape) == 3:
|
71 |
+
h, w, c = img.shape
|
72 |
+
elif len(img.shape) == 4:
|
73 |
+
b, h, w, c = img.shape
|
74 |
+
else:
|
75 |
+
raise ValueError
|
76 |
+
crop_x = min(h, w)
|
77 |
+
crop_y = crop_x
|
78 |
+
start_x = w // 2 - crop_x // 2
|
79 |
+
start_y = h // 2 - crop_y // 2
|
80 |
+
if len(img.shape) == 2:
|
81 |
+
canvas = img[start_y:start_y + crop_y, start_x:start_x + crop_x]
|
82 |
+
elif len(img.shape) == 3:
|
83 |
+
canvas = img[start_y:start_y + crop_y, start_x:start_x + crop_x, :]
|
84 |
+
elif len(img.shape) == 4:
|
85 |
+
canvas = img[:, start_y:start_y + crop_y, start_x:start_x + crop_x, :]
|
86 |
+
if return_starts:
|
87 |
+
return canvas, -start_x, -start_y
|
88 |
+
else:
|
89 |
+
return canvas
|
90 |
+
|
91 |
+
|
92 |
+
def pad_to_square_np(img, till_divisible_by=1, return_starts=False):
|
93 |
+
if len(img.shape) == 2:
|
94 |
+
h, w = img.shape
|
95 |
+
elif len(img.shape) == 3:
|
96 |
+
h, w, c = img.shape
|
97 |
+
elif len(img.shape) == 4:
|
98 |
+
b, h, w, c = img.shape
|
99 |
+
else:
|
100 |
+
raise ValueError
|
101 |
+
if till_divisible_by == 1:
|
102 |
+
size = max(h, w)
|
103 |
+
else:
|
104 |
+
size = (max(h, w) + till_divisible_by) - (max(h, w) % till_divisible_by)
|
105 |
+
start_x = size // 2 - w // 2
|
106 |
+
start_y = size // 2 - h // 2
|
107 |
+
if len(img.shape) == 2:
|
108 |
+
canvas = np.zeros([size, size], dtype=img.dtype)
|
109 |
+
canvas[start_y:start_y + h, start_x:start_x + w] = img
|
110 |
+
elif len(img.shape) == 3:
|
111 |
+
canvas = np.zeros([size, size, c], dtype=img.dtype)
|
112 |
+
canvas[start_y:start_y + h, start_x:start_x + w, :] = img
|
113 |
+
elif len(img.shape) == 4:
|
114 |
+
canvas = np.zeros([b, size, size, c], dtype=img.dtype)
|
115 |
+
canvas[:, start_y:start_y + h, start_x:start_x + w, :] = img
|
116 |
+
if return_starts:
|
117 |
+
return canvas, start_x, start_y
|
118 |
+
else:
|
119 |
+
return canvas
|
120 |
+
|
121 |
+
|
122 |
+
def stretch_to_square_np(img):
|
123 |
+
size = max(*img.shape[:2])
|
124 |
+
return np.array(PIL.Image.fromarray(img).resize((size, size), resample=PIL.Image.BILINEAR))
|
125 |
+
|
126 |
+
|
127 |
+
def rotate_image(image, angle, interpolation=cv2.INTER_LINEAR):
|
128 |
+
image_center = tuple(np.array(image.shape[1::-1]) / 2)
|
129 |
+
rot_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
|
130 |
+
result = cv2.warpAffine(image, rot_mat, image.shape[1::-1], flags=interpolation)
|
131 |
+
return result
|
132 |
+
|
133 |
+
|
134 |
+
def read_array(path):
|
135 |
+
'''
|
136 |
+
https://github.com/colmap/colmap/blob/dev/scripts/python/read_dense.py
|
137 |
+
'''
|
138 |
+
with open(path, "rb") as fid:
|
139 |
+
width, height, channels = np.genfromtxt(fid, delimiter="&", max_rows=1,
|
140 |
+
usecols=(0, 1, 2), dtype=int)
|
141 |
+
fid.seek(0)
|
142 |
+
num_delimiter = 0
|
143 |
+
byte = fid.read(1)
|
144 |
+
while True:
|
145 |
+
if byte == b"&":
|
146 |
+
num_delimiter += 1
|
147 |
+
if num_delimiter >= 3:
|
148 |
+
break
|
149 |
+
byte = fid.read(1)
|
150 |
+
array = np.fromfile(fid, np.float32)
|
151 |
+
array = array.reshape((width, height, channels), order="F")
|
152 |
+
return np.transpose(array, (1, 0, 2)).squeeze()
|
153 |
+
|
154 |
+
|
155 |
+
################ Content ################
|
156 |
+
|
157 |
+
|
158 |
+
class CapturedContent(abc.ABC):
|
159 |
+
def __init__(self):
|
160 |
+
self._rotation = 0
|
161 |
+
|
162 |
+
@property
|
163 |
+
def rotation(self):
|
164 |
+
return self._rotation
|
165 |
+
|
166 |
+
@rotation.setter
|
167 |
+
def rotation(self, rot):
|
168 |
+
self._rotation = rot
|
169 |
+
|
170 |
+
|
171 |
+
class CapturedImage(CapturedContent):
|
172 |
+
def __init__(self, img_path, crop_cam, pinhole_cam_before=None):
|
173 |
+
super(CapturedImage, self).__init__()
|
174 |
+
assert os.path.isfile(img_path), 'file does not exist: {0}'.format(img_path)
|
175 |
+
self.crop_cam = crop_cam
|
176 |
+
self._image = None
|
177 |
+
self.img_path = img_path
|
178 |
+
self.pinhole_cam_before = pinhole_cam_before
|
179 |
+
self._p2d = None
|
180 |
+
|
181 |
+
def read_image_to_ram(self) -> int:
|
182 |
+
# raise NotImplementedError
|
183 |
+
assert self._image is None
|
184 |
+
_image = self.image
|
185 |
+
self._image = _image
|
186 |
+
return self._image.nbytes
|
187 |
+
|
188 |
+
@property
|
189 |
+
def image(self):
|
190 |
+
if self._image is not None:
|
191 |
+
_image = self._image
|
192 |
+
else:
|
193 |
+
_image = imageio.imread(self.img_path, pilmode='RGB')
|
194 |
+
if self.rotation != 0:
|
195 |
+
_image = rotate_image(_image, self.rotation)
|
196 |
+
if _image.shape[:2] != self.pinhole_cam_before.shape:
|
197 |
+
_image = np.array(PIL.Image.fromarray(_image).resize(self.pinhole_cam_before.shape[::-1], resample=PIL.Image.BILINEAR))
|
198 |
+
assert _image.shape[:2] == self.pinhole_cam_before.shape
|
199 |
+
if self.crop_cam == 'no_crop':
|
200 |
+
pass
|
201 |
+
elif self.crop_cam == 'crop_center':
|
202 |
+
_image = crop_center_max(_image)
|
203 |
+
elif self.crop_cam == 'crop_center_and_resize':
|
204 |
+
_image = crop_center_max(_image)
|
205 |
+
_image = np.array(PIL.Image.fromarray(_image).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
|
206 |
+
elif isinstance(self.crop_cam, CropCamConfig):
|
207 |
+
assert _image.shape[0] == self.crop_cam.orig_h
|
208 |
+
assert _image.shape[1] == self.crop_cam.orig_w
|
209 |
+
_image = _image[self.crop_cam.y:self.crop_cam.y + self.crop_cam.h,
|
210 |
+
self.crop_cam.x:self.crop_cam.x + self.crop_cam.w, ]
|
211 |
+
_image = np.array(PIL.Image.fromarray(_image).resize((self.crop_cam.out_w, self.crop_cam.out_h), resample=PIL.Image.BILINEAR))
|
212 |
+
assert _image.shape[:2] == (self.crop_cam.out_h, self.crop_cam.out_w)
|
213 |
+
else:
|
214 |
+
raise ValueError()
|
215 |
+
return _image
|
216 |
+
|
217 |
+
@property
|
218 |
+
def p2d(self):
|
219 |
+
if self._p2d is None:
|
220 |
+
return self._p2d
|
221 |
+
else:
|
222 |
+
_p2d = self._p2d
|
223 |
+
if self.crop_cam == 'no_crop':
|
224 |
+
pass
|
225 |
+
elif self.crop_cam == 'crop_center':
|
226 |
+
_p2d = crop_center_max_xy(_p2d, self.pinhole_cam_before.shape)
|
227 |
+
else:
|
228 |
+
raise ValueError()
|
229 |
+
return _p2d
|
230 |
+
|
231 |
+
@p2d.setter
|
232 |
+
def p2d(self, value):
|
233 |
+
if value is not None:
|
234 |
+
assert isinstance(value, Point2D)
|
235 |
+
self._p2d = value
|
236 |
+
|
237 |
+
|
238 |
+
class CapturedDepth(CapturedContent):
|
239 |
+
def __init__(self, depth_path, crop_cam, pinhole_cam_before=None):
|
240 |
+
super(CapturedDepth, self).__init__()
|
241 |
+
if not depth_path.endswith('dummy'):
|
242 |
+
assert os.path.isfile(depth_path), 'file does not exist: {0}'.format(depth_path)
|
243 |
+
self.crop_cam = crop_cam
|
244 |
+
self._depth = None
|
245 |
+
self.depth_path = depth_path
|
246 |
+
self.pinhole_cam_before = pinhole_cam_before
|
247 |
+
|
248 |
+
def read_depth(self):
|
249 |
+
import tables
|
250 |
+
if self.depth_path.endswith('dummy'):
|
251 |
+
image_path = self.depth_path[:-5]
|
252 |
+
w, h = Image.open(image_path).size
|
253 |
+
_depth = np.zeros([h, w], dtype=np.float32)
|
254 |
+
elif self.depth_path.endswith('.h5'):
|
255 |
+
depth_h5 = tables.open_file(self.depth_path, mode='r')
|
256 |
+
_depth = np.array(depth_h5.root.depth)
|
257 |
+
depth_h5.close()
|
258 |
+
else:
|
259 |
+
raise ValueError
|
260 |
+
return _depth.astype(np.float32)
|
261 |
+
|
262 |
+
def read_depth_to_ram(self) -> int:
|
263 |
+
# raise NotImplementedError
|
264 |
+
assert self._depth is None
|
265 |
+
_depth = self.depth_map
|
266 |
+
self._depth = _depth
|
267 |
+
return self._depth.nbytes
|
268 |
+
|
269 |
+
@property
|
270 |
+
def depth_map(self):
|
271 |
+
if self._depth is not None:
|
272 |
+
_depth = self._depth
|
273 |
+
else:
|
274 |
+
_depth = self.read_depth()
|
275 |
+
if self.rotation != 0:
|
276 |
+
_depth = rotate_image(_depth, self.rotation, interpolation=cv2.INTER_NEAREST)
|
277 |
+
if _depth.shape != self.pinhole_cam_before.shape:
|
278 |
+
_depth = np.array(PIL.Image.fromarray(_depth).resize(self.pinhole_cam_before.shape[::-1], resample=PIL.Image.NEAREST))
|
279 |
+
assert _depth.shape[:2] == self.pinhole_cam_before.shape
|
280 |
+
if self.crop_cam == 'no_crop':
|
281 |
+
pass
|
282 |
+
elif self.crop_cam == 'crop_center':
|
283 |
+
_depth = crop_center_max(_depth)
|
284 |
+
elif self.crop_cam == 'crop_center_and_resize':
|
285 |
+
_depth = crop_center_max(_depth)
|
286 |
+
_depth = np.array(PIL.Image.fromarray(_depth).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.NEAREST))
|
287 |
+
elif isinstance(self.crop_cam, CropCamConfig):
|
288 |
+
assert _depth.shape[0] == self.crop_cam.orig_h
|
289 |
+
assert _depth.shape[1] == self.crop_cam.orig_w
|
290 |
+
_depth = _depth[self.crop_cam.y:self.crop_cam.y + self.crop_cam.h,
|
291 |
+
self.crop_cam.x:self.crop_cam.x + self.crop_cam.w, ]
|
292 |
+
_depth = np.array(PIL.Image.fromarray(_depth).resize((self.crop_cam.out_w, self.crop_cam.out_h), resample=PIL.Image.NEAREST))
|
293 |
+
assert _depth.shape[:2] == (self.crop_cam.out_h, self.crop_cam.out_w)
|
294 |
+
else:
|
295 |
+
raise ValueError()
|
296 |
+
assert (_depth >= 0).all()
|
297 |
+
return _depth
|
298 |
+
|
299 |
+
|
300 |
+
################ Pinhole Capture ################
|
301 |
+
class BasePinholeCapture():
|
302 |
+
def __init__(self, pinhole_cam, cam_pose, crop_cam):
|
303 |
+
self.crop_cam = crop_cam
|
304 |
+
self.cam_pose = cam_pose
|
305 |
+
# modify the camera instrinsics
|
306 |
+
self.pinhole_cam = crop_pinhole_camera(pinhole_cam, crop_cam)
|
307 |
+
self.pinhole_cam_before = pinhole_cam
|
308 |
+
|
309 |
+
def __str__(self):
|
310 |
+
string = 'pinhole camera: {0}\ncamera pose: {1}'.format(self.pinhole_cam, self.cam_pose)
|
311 |
+
return string
|
312 |
+
|
313 |
+
@property
|
314 |
+
def intrinsic_mat(self):
|
315 |
+
return self.pinhole_cam.intrinsic_mat
|
316 |
+
|
317 |
+
@property
|
318 |
+
def extrinsic_mat(self):
|
319 |
+
return self.cam_pose.extrinsic_mat
|
320 |
+
|
321 |
+
@property
|
322 |
+
def shape(self):
|
323 |
+
return self.pinhole_cam.shape
|
324 |
+
|
325 |
+
@property
|
326 |
+
def size(self):
|
327 |
+
return self.shape
|
328 |
+
|
329 |
+
@property
|
330 |
+
def mvp_mat(self):
|
331 |
+
'''
|
332 |
+
model-view-projection matrix (naming from opengl)
|
333 |
+
'''
|
334 |
+
return np.matmul(self.pinhole_cam.intrinsic_mat, self.cam_pose.world_to_camera_3x4)
|
335 |
+
|
336 |
+
|
337 |
+
class RGBPinholeCapture(BasePinholeCapture):
|
338 |
+
def __init__(self, img_path, pinhole_cam, cam_pose, crop_cam):
|
339 |
+
BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam)
|
340 |
+
self.captured_image = CapturedImage(img_path, crop_cam, self.pinhole_cam_before)
|
341 |
+
|
342 |
+
def read_image_to_ram(self) -> int:
|
343 |
+
return self.captured_image.read_image_to_ram()
|
344 |
+
|
345 |
+
@property
|
346 |
+
def img_path(self):
|
347 |
+
return self.captured_image.img_path
|
348 |
+
|
349 |
+
@property
|
350 |
+
def image(self):
|
351 |
+
_image = self.captured_image.image
|
352 |
+
assert _image.shape[0:2] == self.pinhole_cam.shape, 'image shape: {0}, pinhole camera: {1}'.format(_image.shape, self.pinhole_cam)
|
353 |
+
return _image
|
354 |
+
|
355 |
+
@property
|
356 |
+
def seq_id(self):
|
357 |
+
return os.path.dirname(self.captured_image.img_path)
|
358 |
+
|
359 |
+
@property
|
360 |
+
def p2d(self):
|
361 |
+
return self.captured_image.p2d
|
362 |
+
|
363 |
+
@p2d.setter
|
364 |
+
def p2d(self, value):
|
365 |
+
self.captured_image.p2d = value
|
366 |
+
|
367 |
+
|
368 |
+
class DepthPinholeCapture(BasePinholeCapture):
|
369 |
+
def __init__(self, depth_path, pinhole_cam, cam_pose, crop_cam):
|
370 |
+
BasePinholeCapture.__init__(self, pinhole_cam, cam_pose, crop_cam)
|
371 |
+
self.captured_depth = CapturedDepth(depth_path, crop_cam, self.pinhole_cam_before)
|
372 |
+
|
373 |
+
def read_depth_to_ram(self) -> int:
|
374 |
+
return self.captured_depth.read_depth_to_ram()
|
375 |
+
|
376 |
+
@property
|
377 |
+
def depth_path(self):
|
378 |
+
return self.captured_depth.depth_path
|
379 |
+
|
380 |
+
@property
|
381 |
+
def depth_map(self):
|
382 |
+
_depth = self.captured_depth.depth_map
|
383 |
+
# if self.pinhole_cam.shape != _depth.shape:
|
384 |
+
# _depth = misc.imresize(_depth, self.pinhole_cam.shape, interp='nearest', mode='F')
|
385 |
+
assert (_depth >= 0).all()
|
386 |
+
return _depth
|
387 |
+
|
388 |
+
@property
|
389 |
+
def point_cloud_world(self):
|
390 |
+
return self.get_point_cloud_world_from_depth(feat_map=None)
|
391 |
+
|
392 |
+
def get_point_cloud_world_from_depth(self, feat_map=None):
|
393 |
+
_pcd = pcd_projector.PointCloudProjector.img_2d_to_pcd_3d_np(self.depth_map, self.pinhole_cam.intrinsic_mat, img=feat_map, motion=self.cam_pose.camera_to_world).astype(constants.DEFAULT_PRECISION)
|
394 |
+
return _pcd
|
395 |
+
|
396 |
+
|
397 |
+
class RGBDPinholeCapture(RGBPinholeCapture, DepthPinholeCapture):
|
398 |
+
def __init__(self, img_path, depth_path, pinhole_cam, cam_pose, crop_cam):
|
399 |
+
RGBPinholeCapture.__init__(self, img_path, pinhole_cam, cam_pose, crop_cam)
|
400 |
+
DepthPinholeCapture.__init__(self, depth_path, pinhole_cam, cam_pose, crop_cam)
|
401 |
+
|
402 |
+
@property
|
403 |
+
def point_cloud_w_rgb_world(self):
|
404 |
+
return self.get_point_cloud_world_from_depth(feat_map=self.image)
|
405 |
+
|
406 |
+
|
407 |
+
def rotate_capture(cap, rot):
|
408 |
+
if rot == 0:
|
409 |
+
return copy.deepcopy(cap)
|
410 |
+
else:
|
411 |
+
rot_pose = rotate_camera_pose(cap.cam_pose, rot)
|
412 |
+
rot_cap = copy.deepcopy(cap)
|
413 |
+
rot_cap.cam_pose = rot_pose
|
414 |
+
if hasattr(rot_cap, 'captured_image'):
|
415 |
+
rot_cap.captured_image.rotation = rot
|
416 |
+
if hasattr(rot_cap, 'captured_depth'):
|
417 |
+
rot_cap.captured_depth.rotation = rot
|
418 |
+
return rot_cap
|
419 |
+
|
420 |
+
|
421 |
+
def crop_capture(cap, crop_cam):
|
422 |
+
if isinstance(cap, RGBDPinholeCapture):
|
423 |
+
cropped_cap = RGBDPinholeCapture(cap.img_path, cap.depth_path, cap.pinhole_cam, cap.cam_pose, crop_cam)
|
424 |
+
elif isinstance(cap, RGBPinholeCapture):
|
425 |
+
cropped_cap = RGBPinholeCapture(cap.img_path, cap.pinhole_cam, cap.cam_pose, crop_cam)
|
426 |
+
else:
|
427 |
+
raise ValueError
|
428 |
+
if hasattr(cropped_cap, 'captured_image'):
|
429 |
+
cropped_cap.captured_image.rotation = cap.captured_image.rotation
|
430 |
+
if hasattr(cropped_cap, 'captured_depth'):
|
431 |
+
cropped_cap.captured_depth.rotation = cap.captured_depth.rotation
|
432 |
+
return cropped_cap
|
third_party/COTR/COTR/cameras/pinhole_camera.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Static pinhole camera
|
3 |
+
"""
|
4 |
+
|
5 |
+
import copy
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from COTR.utils import constants
|
10 |
+
from COTR.utils.constants import MAX_SIZE
|
11 |
+
from COTR.utils.utils import CropCamConfig
|
12 |
+
|
13 |
+
|
14 |
+
class PinholeCamera():
|
15 |
+
def __init__(self, width, height, fx, fy, cx, cy):
|
16 |
+
self.width = int(width)
|
17 |
+
self.height = int(height)
|
18 |
+
self.fx = fx
|
19 |
+
self.fy = fy
|
20 |
+
self.cx = cx
|
21 |
+
self.cy = cy
|
22 |
+
|
23 |
+
def __str__(self):
|
24 |
+
string = 'width: {0}, height: {1}, fx: {2}, fy: {3}, cx: {4}, cy: {5}'.format(self.width, self.height, self.fx, self.fy, self.cx, self.cy)
|
25 |
+
return string
|
26 |
+
|
27 |
+
@property
|
28 |
+
def shape(self):
|
29 |
+
return (self.height, self.width)
|
30 |
+
|
31 |
+
@property
|
32 |
+
def intrinsic_mat(self):
|
33 |
+
mat = np.array([[self.fx, 0.0, self.cx],
|
34 |
+
[0.0, self.fy, self.cy],
|
35 |
+
[0.0, 0.0, 1.0]], dtype=constants.DEFAULT_PRECISION)
|
36 |
+
return mat
|
37 |
+
|
38 |
+
|
39 |
+
def rotate_pinhole_camera(cam, rot):
|
40 |
+
assert 0, 'TODO: Camera should stay the same while rotation'
|
41 |
+
assert rot in [0, 90, 180, 270], 'only support 0/90/180/270 degrees rotation'
|
42 |
+
if rot in [0, 180]:
|
43 |
+
return copy.deepcopy(cam)
|
44 |
+
elif rot in [90, 270]:
|
45 |
+
return PinholeCamera(width=cam.height, height=cam.width, fx=cam.fy, fy=cam.fx, cx=cam.cy, cy=cam.cx)
|
46 |
+
else:
|
47 |
+
raise NotImplementedError
|
48 |
+
|
49 |
+
|
50 |
+
def crop_pinhole_camera(pinhole_cam, crop_cam):
|
51 |
+
if crop_cam == 'no_crop':
|
52 |
+
cropped_pinhole_cam = pinhole_cam
|
53 |
+
elif crop_cam == 'crop_center':
|
54 |
+
_h = _w = min(*pinhole_cam.shape)
|
55 |
+
_cx = _cy = _h / 2
|
56 |
+
cropped_pinhole_cam = PinholeCamera(_w, _h, pinhole_cam.fx, pinhole_cam.fy, _cx, _cy)
|
57 |
+
elif crop_cam == 'crop_center_and_resize':
|
58 |
+
_h = _w = MAX_SIZE
|
59 |
+
_cx = _cy = MAX_SIZE / 2
|
60 |
+
scale = MAX_SIZE / min(*pinhole_cam.shape)
|
61 |
+
cropped_pinhole_cam = PinholeCamera(_w, _h, pinhole_cam.fx * scale, pinhole_cam.fy * scale, _cx, _cy)
|
62 |
+
elif isinstance(crop_cam, CropCamConfig):
|
63 |
+
scale = crop_cam.out_h / crop_cam.h
|
64 |
+
cropped_pinhole_cam = PinholeCamera(crop_cam.out_w,
|
65 |
+
crop_cam.out_h,
|
66 |
+
pinhole_cam.fx * scale,
|
67 |
+
pinhole_cam.fy * scale,
|
68 |
+
(pinhole_cam.cx - crop_cam.x) * scale,
|
69 |
+
(pinhole_cam.cy - crop_cam.y) * scale
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
raise ValueError
|
73 |
+
return cropped_pinhole_cam
|
third_party/COTR/COTR/datasets/colmap_helper.py
ADDED
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
assert sys.version_info >= (3, 7), 'ordered dict is required'
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
from collections import namedtuple
|
6 |
+
import json
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from COTR.utils import debug_utils
|
12 |
+
from COTR.cameras.pinhole_camera import PinholeCamera
|
13 |
+
from COTR.cameras.camera_pose import CameraPose
|
14 |
+
from COTR.cameras.capture import RGBPinholeCapture, RGBDPinholeCapture
|
15 |
+
from COTR.cameras import capture
|
16 |
+
from COTR.transformations import transformations
|
17 |
+
from COTR.transformations.transform_basics import Translation, Rotation
|
18 |
+
from COTR.sfm_scenes import sfm_scenes
|
19 |
+
from COTR.global_configs import dataset_config
|
20 |
+
from COTR.utils.utils import Point2D, Point3D
|
21 |
+
|
22 |
+
ImageMeta = namedtuple('ImageMeta', ['image_id', 'r', 't', 'camera_id', 'image_path', 'point3d_id', 'p2d'])
|
23 |
+
COVISIBILITY_CHECK = False
|
24 |
+
LOAD_PCD = False
|
25 |
+
|
26 |
+
|
27 |
+
class ColmapAsciiReader():
|
28 |
+
def __init__(self):
|
29 |
+
pass
|
30 |
+
|
31 |
+
@classmethod
|
32 |
+
def read_sfm_scene(cls, scene_dir, images_dir, crop_cam):
|
33 |
+
point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
|
34 |
+
cameras_path = os.path.join(scene_dir, 'cameras.txt')
|
35 |
+
images_path = os.path.join(scene_dir, 'images.txt')
|
36 |
+
captures = cls.read_captures(images_path, cameras_path, images_dir, crop_cam)
|
37 |
+
if LOAD_PCD:
|
38 |
+
point_cloud = cls.read_point_cloud(point_cloud_path)
|
39 |
+
else:
|
40 |
+
point_cloud = None
|
41 |
+
sfm_scene = sfm_scenes.SfmScene(captures, point_cloud)
|
42 |
+
return sfm_scene
|
43 |
+
|
44 |
+
@staticmethod
|
45 |
+
def read_point_cloud(points_txt_path):
|
46 |
+
with open(points_txt_path, "r") as fid:
|
47 |
+
line = fid.readline()
|
48 |
+
assert line == '# 3D point list with one line of data per point:\n'
|
49 |
+
line = fid.readline()
|
50 |
+
assert line == '# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n'
|
51 |
+
line = fid.readline()
|
52 |
+
assert re.search('^# Number of points: \d+, mean track length: [-+]?\d*\.\d+|\d+\n$', line)
|
53 |
+
num_points, mean_track_length = re.findall(r"[-+]?\d*\.\d+|\d+", line)
|
54 |
+
num_points = int(num_points)
|
55 |
+
mean_track_length = float(mean_track_length)
|
56 |
+
|
57 |
+
xyz = np.zeros((num_points, 3), dtype=np.float32)
|
58 |
+
rgb = np.zeros((num_points, 3), dtype=np.float32)
|
59 |
+
if COVISIBILITY_CHECK:
|
60 |
+
point_meta = {}
|
61 |
+
|
62 |
+
for i in tqdm(range(num_points), desc='reading point cloud'):
|
63 |
+
elems = fid.readline().split()
|
64 |
+
xyz[i] = list(map(float, elems[1:4]))
|
65 |
+
rgb[i] = list(map(int, elems[4:7]))
|
66 |
+
if COVISIBILITY_CHECK:
|
67 |
+
point_id = int(elems[0])
|
68 |
+
image_ids = np.array(tuple(map(int, elems[8::2])))
|
69 |
+
point_meta[point_id] = Point3D(id=point_id,
|
70 |
+
arr_idx=i,
|
71 |
+
image_ids=image_ids)
|
72 |
+
pcd = np.concatenate([xyz, rgb], axis=1)
|
73 |
+
if COVISIBILITY_CHECK:
|
74 |
+
return pcd, point_meta
|
75 |
+
else:
|
76 |
+
return pcd
|
77 |
+
|
78 |
+
@classmethod
|
79 |
+
def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, crop_cam):
|
80 |
+
captures = []
|
81 |
+
cameras = cls.read_cameras(cameras_txt_path)
|
82 |
+
images_meta = cls.read_images_meta(images_txt_path, images_dir)
|
83 |
+
for key in images_meta.keys():
|
84 |
+
cur_cam_id = images_meta[key].camera_id
|
85 |
+
cur_cam = cameras[cur_cam_id]
|
86 |
+
cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r)
|
87 |
+
cur_image_path = images_meta[key].image_path
|
88 |
+
cap = RGBPinholeCapture(cur_image_path, cur_cam, cur_camera_pose, crop_cam)
|
89 |
+
captures.append(cap)
|
90 |
+
return captures
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def read_cameras(cls, cameras_txt_path):
|
94 |
+
cameras = {}
|
95 |
+
with open(cameras_txt_path, "r") as fid:
|
96 |
+
line = fid.readline()
|
97 |
+
assert line == '# Camera list with one line of data per camera:\n'
|
98 |
+
line = fid.readline()
|
99 |
+
assert line == '# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n'
|
100 |
+
line = fid.readline()
|
101 |
+
assert re.search('^# Number of cameras: \d+\n$', line)
|
102 |
+
num_cams = int(re.findall(r"[-+]?\d*\.\d+|\d+", line)[0])
|
103 |
+
|
104 |
+
for _ in tqdm(range(num_cams), desc='reading cameras'):
|
105 |
+
elems = fid.readline().split()
|
106 |
+
camera_id = int(elems[0])
|
107 |
+
camera_type = elems[1]
|
108 |
+
if camera_type == "PINHOLE":
|
109 |
+
width, height, focal_length_x, focal_length_y, cx, cy = list(map(float, elems[2:8]))
|
110 |
+
else:
|
111 |
+
raise ValueError('Please rectify the 3D model to pinhole camera.')
|
112 |
+
cur_cam = PinholeCamera(width, height, focal_length_x, focal_length_y, cx, cy)
|
113 |
+
assert camera_id not in cameras
|
114 |
+
cameras[camera_id] = cur_cam
|
115 |
+
return cameras
|
116 |
+
|
117 |
+
@classmethod
|
118 |
+
def read_images_meta(cls, images_txt_path, images_dir):
|
119 |
+
images_meta = {}
|
120 |
+
with open(images_txt_path, "r") as fid:
|
121 |
+
line = fid.readline()
|
122 |
+
assert line == '# Image list with two lines of data per image:\n'
|
123 |
+
line = fid.readline()
|
124 |
+
assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
|
125 |
+
line = fid.readline()
|
126 |
+
assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
|
127 |
+
line = fid.readline()
|
128 |
+
assert re.search('^# Number of images: \d+, mean observations per image: [-+]?\d*\.\d+|\d+\n$', line)
|
129 |
+
num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line)
|
130 |
+
num_images = int(num_images)
|
131 |
+
mean_ob_per_img = float(mean_ob_per_img)
|
132 |
+
|
133 |
+
for _ in tqdm(range(num_images), desc='reading images meta'):
|
134 |
+
elems = fid.readline().split()
|
135 |
+
assert len(elems) == 10
|
136 |
+
|
137 |
+
image_path = os.path.join(images_dir, elems[9])
|
138 |
+
assert os.path.isfile(image_path)
|
139 |
+
image_id = int(elems[0])
|
140 |
+
qw, qx, qy, qz, tx, ty, tz = list(map(float, elems[1:8]))
|
141 |
+
t = Translation(np.array([tx, ty, tz], dtype=np.float32))
|
142 |
+
r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32))
|
143 |
+
camera_id = int(elems[8])
|
144 |
+
assert image_id not in images_meta
|
145 |
+
|
146 |
+
line = fid.readline()
|
147 |
+
if COVISIBILITY_CHECK:
|
148 |
+
elems = line.split()
|
149 |
+
elems = list(map(float, elems))
|
150 |
+
elems = np.array(elems).reshape(-1, 3)
|
151 |
+
point3d_id = set(elems[elems[:, 2] != -1][:, 2].astype(np.int))
|
152 |
+
point3d_id = np.sort(np.array(list(point3d_id)))
|
153 |
+
xyi = elems[elems[:, 2] != -1]
|
154 |
+
xy = xyi[:, :2]
|
155 |
+
idx = xyi[:, 2].astype(np.int)
|
156 |
+
p2d = Point2D(idx, xy)
|
157 |
+
else:
|
158 |
+
point3d_id = None
|
159 |
+
p2d = None
|
160 |
+
|
161 |
+
images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d)
|
162 |
+
return images_meta
|
163 |
+
|
164 |
+
|
165 |
+
class ColmapWithDepthAsciiReader(ColmapAsciiReader):
|
166 |
+
'''
|
167 |
+
Not all images have usable depth estimate from colmap.
|
168 |
+
A valid list is needed.
|
169 |
+
'''
|
170 |
+
|
171 |
+
@classmethod
|
172 |
+
def read_sfm_scene(cls, scene_dir, images_dir, depth_dir, crop_cam):
|
173 |
+
point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
|
174 |
+
cameras_path = os.path.join(scene_dir, 'cameras.txt')
|
175 |
+
images_path = os.path.join(scene_dir, 'images.txt')
|
176 |
+
captures = cls.read_captures(images_path, cameras_path, images_dir, depth_dir, crop_cam)
|
177 |
+
if LOAD_PCD:
|
178 |
+
point_cloud = cls.read_point_cloud(point_cloud_path)
|
179 |
+
else:
|
180 |
+
point_cloud = None
|
181 |
+
sfm_scene = sfm_scenes.SfmScene(captures, point_cloud)
|
182 |
+
return sfm_scene
|
183 |
+
|
184 |
+
@classmethod
|
185 |
+
def read_sfm_scene_given_valid_list_path(cls, scene_dir, images_dir, depth_dir, valid_list_json_path, crop_cam):
|
186 |
+
point_cloud_path = os.path.join(scene_dir, 'points3D.txt')
|
187 |
+
cameras_path = os.path.join(scene_dir, 'cameras.txt')
|
188 |
+
images_path = os.path.join(scene_dir, 'images.txt')
|
189 |
+
valid_list = cls.read_valid_list(valid_list_json_path)
|
190 |
+
captures = cls.read_captures_with_depth_given_valid_list(images_path, cameras_path, images_dir, depth_dir, valid_list, crop_cam)
|
191 |
+
if LOAD_PCD:
|
192 |
+
point_cloud = cls.read_point_cloud(point_cloud_path)
|
193 |
+
else:
|
194 |
+
point_cloud = None
|
195 |
+
sfm_scene = sfm_scenes.SfmScene(captures, point_cloud)
|
196 |
+
return sfm_scene
|
197 |
+
|
198 |
+
@classmethod
|
199 |
+
def read_captures(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, crop_cam):
|
200 |
+
captures = []
|
201 |
+
cameras = cls.read_cameras(cameras_txt_path)
|
202 |
+
images_meta = cls.read_images_meta(images_txt_path, images_dir)
|
203 |
+
for key in images_meta.keys():
|
204 |
+
cur_cam_id = images_meta[key].camera_id
|
205 |
+
cur_cam = cameras[cur_cam_id]
|
206 |
+
cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r)
|
207 |
+
cur_image_path = images_meta[key].image_path
|
208 |
+
try:
|
209 |
+
cur_depth_path = cls.image_path_2_depth_path(cur_image_path[len(images_dir) + 1:], depth_dir)
|
210 |
+
except:
|
211 |
+
print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir))
|
212 |
+
# TODO
|
213 |
+
# continue
|
214 |
+
# exec(debug_utils.embed_breakpoint())
|
215 |
+
cur_depth_path = f'{cur_image_path}dummy'
|
216 |
+
|
217 |
+
cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, crop_cam)
|
218 |
+
cap.point3d_id = images_meta[key].point3d_id
|
219 |
+
cap.p2d = images_meta[key].p2d
|
220 |
+
cap.image_id = key
|
221 |
+
captures.append(cap)
|
222 |
+
return captures
|
223 |
+
|
224 |
+
@classmethod
|
225 |
+
def read_captures_with_depth_given_valid_list(cls, images_txt_path, cameras_txt_path, images_dir, depth_dir, valid_list, crop_cam):
|
226 |
+
captures = []
|
227 |
+
cameras = cls.read_cameras(cameras_txt_path)
|
228 |
+
images_meta = cls.read_images_meta_given_valid_list(images_txt_path, images_dir, valid_list)
|
229 |
+
for key in images_meta.keys():
|
230 |
+
cur_cam_id = images_meta[key].camera_id
|
231 |
+
cur_cam = cameras[cur_cam_id]
|
232 |
+
cur_camera_pose = CameraPose(images_meta[key].t, images_meta[key].r)
|
233 |
+
cur_image_path = images_meta[key].image_path
|
234 |
+
try:
|
235 |
+
cur_depth_path = cls.image_path_2_depth_path(cur_image_path, depth_dir)
|
236 |
+
except:
|
237 |
+
print('{0} does not have depth at {1}'.format(cur_image_path, depth_dir))
|
238 |
+
continue
|
239 |
+
cap = RGBDPinholeCapture(cur_image_path, cur_depth_path, cur_cam, cur_camera_pose, crop_cam)
|
240 |
+
cap.point3d_id = images_meta[key].point3d_id
|
241 |
+
cap.p2d = images_meta[key].p2d
|
242 |
+
cap.image_id = key
|
243 |
+
captures.append(cap)
|
244 |
+
return captures
|
245 |
+
|
246 |
+
@classmethod
|
247 |
+
def read_images_meta_given_valid_list(cls, images_txt_path, images_dir, valid_list):
|
248 |
+
images_meta = {}
|
249 |
+
with open(images_txt_path, "r") as fid:
|
250 |
+
line = fid.readline()
|
251 |
+
assert line == '# Image list with two lines of data per image:\n'
|
252 |
+
line = fid.readline()
|
253 |
+
assert line == '# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n'
|
254 |
+
line = fid.readline()
|
255 |
+
assert line == '# POINTS2D[] as (X, Y, POINT3D_ID)\n'
|
256 |
+
line = fid.readline()
|
257 |
+
assert re.search('^# Number of images: \d+, mean observations per image:[-+]?\d*\.\d+|\d+\n$', line), line
|
258 |
+
num_images, mean_ob_per_img = re.findall(r"[-+]?\d*\.\d+|\d+", line)
|
259 |
+
num_images = int(num_images)
|
260 |
+
mean_ob_per_img = float(mean_ob_per_img)
|
261 |
+
|
262 |
+
for _ in tqdm(range(num_images), desc='reading images meta'):
|
263 |
+
elems = fid.readline().split()
|
264 |
+
assert len(elems) == 10
|
265 |
+
line = fid.readline()
|
266 |
+
image_path = os.path.join(images_dir, elems[9])
|
267 |
+
prefix = os.path.abspath(os.path.join(image_path, '../../../../')) + '/'
|
268 |
+
rel_image_path = image_path.replace(prefix, '')
|
269 |
+
if rel_image_path not in valid_list:
|
270 |
+
continue
|
271 |
+
assert os.path.isfile(image_path), '{0} is not existing'.format(image_path)
|
272 |
+
image_id = int(elems[0])
|
273 |
+
qw, qx, qy, qz, tx, ty, tz = list(map(float, elems[1:8]))
|
274 |
+
t = Translation(np.array([tx, ty, tz], dtype=np.float32))
|
275 |
+
r = Rotation(np.array([qw, qx, qy, qz], dtype=np.float32))
|
276 |
+
camera_id = int(elems[8])
|
277 |
+
assert image_id not in images_meta
|
278 |
+
|
279 |
+
if COVISIBILITY_CHECK:
|
280 |
+
elems = line.split()
|
281 |
+
elems = list(map(float, elems))
|
282 |
+
elems = np.array(elems).reshape(-1, 3)
|
283 |
+
point3d_id = set(elems[elems[:, 2] != -1][:, 2].astype(np.int))
|
284 |
+
point3d_id = np.sort(np.array(list(point3d_id)))
|
285 |
+
xyi = elems[elems[:, 2] != -1]
|
286 |
+
xy = xyi[:, :2]
|
287 |
+
idx = xyi[:, 2].astype(np.int)
|
288 |
+
p2d = Point2D(idx, xy)
|
289 |
+
else:
|
290 |
+
point3d_id = None
|
291 |
+
p2d = None
|
292 |
+
images_meta[image_id] = ImageMeta(image_id, r, t, camera_id, image_path, point3d_id, p2d)
|
293 |
+
return images_meta
|
294 |
+
|
295 |
+
@classmethod
|
296 |
+
def read_valid_list(cls, valid_list_json_path):
|
297 |
+
assert os.path.isfile(valid_list_json_path), valid_list_json_path
|
298 |
+
with open(valid_list_json_path, 'r') as f:
|
299 |
+
valid_list = json.load(f)
|
300 |
+
assert len(valid_list) == len(set(valid_list))
|
301 |
+
return set(valid_list)
|
302 |
+
|
303 |
+
@classmethod
|
304 |
+
def image_path_2_depth_path(cls, image_path, depth_dir):
|
305 |
+
depth_file = os.path.splitext(os.path.basename(image_path))[0] + '.h5'
|
306 |
+
depth_path = os.path.join(depth_dir, depth_file)
|
307 |
+
if not os.path.isfile(depth_path):
|
308 |
+
# depth_file = image_path + '.photometric.bin'
|
309 |
+
depth_file = image_path + '.geometric.bin'
|
310 |
+
depth_path = os.path.join(depth_dir, depth_file)
|
311 |
+
assert os.path.isfile(depth_path), '{0} is not file'.format(depth_path)
|
312 |
+
return depth_path
|
third_party/COTR/COTR/datasets/cotr_dataset.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
COTR dataset
|
3 |
+
'''
|
4 |
+
|
5 |
+
import random
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
from torchvision.transforms import functional as tvtf
|
10 |
+
from torch.utils import data
|
11 |
+
|
12 |
+
from COTR.datasets import megadepth_dataset
|
13 |
+
from COTR.utils import debug_utils, utils, constants
|
14 |
+
from COTR.projector import pcd_projector
|
15 |
+
from COTR.cameras import capture
|
16 |
+
from COTR.utils.utils import CropCamConfig
|
17 |
+
from COTR.inference import inference_helper
|
18 |
+
from COTR.inference.inference_helper import two_images_side_by_side
|
19 |
+
|
20 |
+
|
21 |
+
class COTRDataset(data.Dataset):
|
22 |
+
def __init__(self, opt, dataset_type: str):
|
23 |
+
assert dataset_type in ['train', 'val', 'test']
|
24 |
+
assert len(opt.scenes_name_list) > 0
|
25 |
+
self.opt = opt
|
26 |
+
self.dataset_type = dataset_type
|
27 |
+
self.sfm_dataset = megadepth_dataset.MegadepthDataset(opt, dataset_type)
|
28 |
+
|
29 |
+
self.kp_pool = opt.kp_pool
|
30 |
+
self.num_kp = opt.num_kp
|
31 |
+
self.bidirectional = opt.bidirectional
|
32 |
+
self.need_rotation = opt.need_rotation
|
33 |
+
self.max_rotation = opt.max_rotation
|
34 |
+
self.rotation_chance = opt.rotation_chance
|
35 |
+
|
36 |
+
def _trim_corrs(self, in_corrs):
|
37 |
+
length = in_corrs.shape[0]
|
38 |
+
if length >= self.num_kp:
|
39 |
+
mask = np.random.choice(length, self.num_kp)
|
40 |
+
return in_corrs[mask]
|
41 |
+
else:
|
42 |
+
mask = np.random.choice(length, self.num_kp - length)
|
43 |
+
return np.concatenate([in_corrs, in_corrs[mask]], axis=0)
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
if self.dataset_type == 'val':
|
47 |
+
return min(1000, self.sfm_dataset.num_queries)
|
48 |
+
else:
|
49 |
+
return self.sfm_dataset.num_queries
|
50 |
+
|
51 |
+
def augment_with_rotation(self, query_cap, nn_cap):
|
52 |
+
if random.random() < self.rotation_chance:
|
53 |
+
theta = np.random.uniform(low=-1, high=1) * self.max_rotation
|
54 |
+
query_cap = capture.rotate_capture(query_cap, theta)
|
55 |
+
if random.random() < self.rotation_chance:
|
56 |
+
theta = np.random.uniform(low=-1, high=1) * self.max_rotation
|
57 |
+
nn_cap = capture.rotate_capture(nn_cap, theta)
|
58 |
+
return query_cap, nn_cap
|
59 |
+
|
60 |
+
def __getitem__(self, index):
|
61 |
+
assert self.opt.k_size == 1
|
62 |
+
query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index)
|
63 |
+
nn_cap = nn_caps[0]
|
64 |
+
|
65 |
+
if self.need_rotation:
|
66 |
+
query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap)
|
67 |
+
|
68 |
+
nn_keypoints_y, nn_keypoints_x = np.where(nn_cap.depth_map > 0)
|
69 |
+
nn_keypoints_y = nn_keypoints_y[..., None]
|
70 |
+
nn_keypoints_x = nn_keypoints_x[..., None]
|
71 |
+
nn_keypoints_z = nn_cap.depth_map[np.floor(nn_keypoints_y).astype('int'), np.floor(nn_keypoints_x).astype('int')]
|
72 |
+
nn_keypoints_xy = np.concatenate([nn_keypoints_x, nn_keypoints_y], axis=1)
|
73 |
+
nn_keypoints_3d_world, valid_index_1 = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(nn_keypoints_xy, nn_keypoints_z, nn_cap.pinhole_cam.intrinsic_mat, motion=nn_cap.cam_pose.camera_to_world, return_index=True)
|
74 |
+
|
75 |
+
query_keypoints_xyz, valid_index_2 = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np(
|
76 |
+
nn_keypoints_3d_world,
|
77 |
+
query_cap.pinhole_cam.intrinsic_mat,
|
78 |
+
query_cap.cam_pose.world_to_camera[0:3, :],
|
79 |
+
query_cap.image.shape[:2],
|
80 |
+
keep_z=True,
|
81 |
+
crop=True,
|
82 |
+
filter_neg=True,
|
83 |
+
norm_coord=False,
|
84 |
+
return_index=True,
|
85 |
+
)
|
86 |
+
query_keypoints_xy = query_keypoints_xyz[:, 0:2]
|
87 |
+
query_keypoints_z_proj = query_keypoints_xyz[:, 2:3]
|
88 |
+
query_keypoints_z = query_cap.depth_map[np.floor(query_keypoints_xy[:, 1:2]).astype('int'), np.floor(query_keypoints_xy[:, 0:1]).astype('int')]
|
89 |
+
mask = (abs(query_keypoints_z - query_keypoints_z_proj) < 0.5)[:, 0]
|
90 |
+
query_keypoints_xy = query_keypoints_xy[mask]
|
91 |
+
|
92 |
+
if query_keypoints_xy.shape[0] < self.num_kp:
|
93 |
+
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
94 |
+
|
95 |
+
nn_keypoints_xy = nn_keypoints_xy[valid_index_1][valid_index_2][mask]
|
96 |
+
assert nn_keypoints_xy.shape == query_keypoints_xy.shape
|
97 |
+
corrs = np.concatenate([query_keypoints_xy, nn_keypoints_xy], axis=1)
|
98 |
+
corrs = self._trim_corrs(corrs)
|
99 |
+
# flip augmentation
|
100 |
+
if np.random.uniform() < 0.5:
|
101 |
+
corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0]
|
102 |
+
corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2]
|
103 |
+
sbs_img = two_images_side_by_side(np.fliplr(query_cap.image), np.fliplr(nn_cap.image))
|
104 |
+
else:
|
105 |
+
sbs_img = two_images_side_by_side(query_cap.image, nn_cap.image)
|
106 |
+
corrs[:, 2] += constants.MAX_SIZE
|
107 |
+
corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
|
108 |
+
assert (0.0 <= corrs[:, 0]).all() and (corrs[:, 0] <= 0.5).all()
|
109 |
+
assert (0.0 <= corrs[:, 1]).all() and (corrs[:, 1] <= 1.0).all()
|
110 |
+
assert (0.5 <= corrs[:, 2]).all() and (corrs[:, 2] <= 1.0).all()
|
111 |
+
assert (0.0 <= corrs[:, 3]).all() and (corrs[:, 3] <= 1.0).all()
|
112 |
+
out = {
|
113 |
+
'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
114 |
+
'corrs': torch.from_numpy(corrs).float(),
|
115 |
+
}
|
116 |
+
if self.bidirectional:
|
117 |
+
out['queries'] = torch.from_numpy(np.concatenate([corrs[:, :2], corrs[:, 2:]], axis=0)).float()
|
118 |
+
out['targets'] = torch.from_numpy(np.concatenate([corrs[:, 2:], corrs[:, :2]], axis=0)).float()
|
119 |
+
else:
|
120 |
+
out['queries'] = torch.from_numpy(corrs[:, :2]).float()
|
121 |
+
out['targets'] = torch.from_numpy(corrs[:, 2:]).float()
|
122 |
+
return out
|
123 |
+
|
124 |
+
|
125 |
+
class COTRZoomDataset(COTRDataset):
|
126 |
+
def __init__(self, opt, dataset_type: str):
|
127 |
+
assert opt.crop_cam in ['no_crop', 'crop_center']
|
128 |
+
assert opt.use_ram == False
|
129 |
+
super().__init__(opt, dataset_type)
|
130 |
+
self.zoom_start = opt.zoom_start
|
131 |
+
self.zoom_end = opt.zoom_end
|
132 |
+
self.zoom_levels = opt.zoom_levels
|
133 |
+
self.zoom_jitter = opt.zoom_jitter
|
134 |
+
self.zooms = np.logspace(np.log10(opt.zoom_start),
|
135 |
+
np.log10(opt.zoom_end),
|
136 |
+
num=opt.zoom_levels)
|
137 |
+
|
138 |
+
def get_corrs(self, from_cap, to_cap, reduced_size=None):
|
139 |
+
from_y, from_x = np.where(from_cap.depth_map > 0)
|
140 |
+
from_y, from_x = from_y[..., None], from_x[..., None]
|
141 |
+
if reduced_size is not None:
|
142 |
+
filter_idx = np.random.choice(from_y.shape[0], reduced_size, replace=False)
|
143 |
+
from_y, from_x = from_y[filter_idx], from_x[filter_idx]
|
144 |
+
from_z = from_cap.depth_map[np.floor(from_y).astype('int'), np.floor(from_x).astype('int')]
|
145 |
+
from_xy = np.concatenate([from_x, from_y], axis=1)
|
146 |
+
from_3d_world, valid_index_1 = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(from_xy, from_z, from_cap.pinhole_cam.intrinsic_mat, motion=from_cap.cam_pose.camera_to_world, return_index=True)
|
147 |
+
|
148 |
+
to_xyz, valid_index_2 = pcd_projector.PointCloudProjector.pcd_3d_to_pcd_2d_np(
|
149 |
+
from_3d_world,
|
150 |
+
to_cap.pinhole_cam.intrinsic_mat,
|
151 |
+
to_cap.cam_pose.world_to_camera[0:3, :],
|
152 |
+
to_cap.image.shape[:2],
|
153 |
+
keep_z=True,
|
154 |
+
crop=True,
|
155 |
+
filter_neg=True,
|
156 |
+
norm_coord=False,
|
157 |
+
return_index=True,
|
158 |
+
)
|
159 |
+
|
160 |
+
to_xy = to_xyz[:, 0:2]
|
161 |
+
to_z_proj = to_xyz[:, 2:3]
|
162 |
+
to_z = to_cap.depth_map[np.floor(to_xy[:, 1:2]).astype('int'), np.floor(to_xy[:, 0:1]).astype('int')]
|
163 |
+
mask = (abs(to_z - to_z_proj) < 0.5)[:, 0]
|
164 |
+
if mask.sum() > 0:
|
165 |
+
return np.concatenate([from_xy[valid_index_1][valid_index_2][mask], to_xy[mask]], axis=1)
|
166 |
+
else:
|
167 |
+
return None
|
168 |
+
|
169 |
+
def get_seed_corr(self, from_cap, to_cap, max_try=100):
|
170 |
+
seed_corr = self.get_corrs(from_cap, to_cap, reduced_size=max_try)
|
171 |
+
if seed_corr is None:
|
172 |
+
return None
|
173 |
+
shuffle = np.random.permutation(seed_corr.shape[0])
|
174 |
+
seed_corr = np.take(seed_corr, shuffle, axis=0)
|
175 |
+
return seed_corr[0]
|
176 |
+
|
177 |
+
def get_zoomed_cap(self, cap, pos, scale, jitter):
|
178 |
+
patch = inference_helper.get_patch_centered_at(cap.image, pos, scale=scale, return_content=False)
|
179 |
+
patch = inference_helper.get_patch_centered_at(cap.image,
|
180 |
+
pos + np.array([patch.w, patch.h]) * np.random.uniform(-jitter, jitter, 2),
|
181 |
+
scale=scale,
|
182 |
+
return_content=False)
|
183 |
+
zoom_config = CropCamConfig(x=patch.x,
|
184 |
+
y=patch.y,
|
185 |
+
w=patch.w,
|
186 |
+
h=patch.h,
|
187 |
+
out_w=constants.MAX_SIZE,
|
188 |
+
out_h=constants.MAX_SIZE,
|
189 |
+
orig_w=cap.shape[1],
|
190 |
+
orig_h=cap.shape[0])
|
191 |
+
zoom_cap = capture.crop_capture(cap, zoom_config)
|
192 |
+
return zoom_cap
|
193 |
+
|
194 |
+
def __getitem__(self, index):
|
195 |
+
assert self.opt.k_size == 1
|
196 |
+
query_cap, nn_caps = self.sfm_dataset.get_query_with_knn(index)
|
197 |
+
nn_cap = nn_caps[0]
|
198 |
+
if self.need_rotation:
|
199 |
+
query_cap, nn_cap = self.augment_with_rotation(query_cap, nn_cap)
|
200 |
+
|
201 |
+
# find seed
|
202 |
+
seed_corr = self.get_seed_corr(nn_cap, query_cap)
|
203 |
+
if seed_corr is None:
|
204 |
+
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
205 |
+
|
206 |
+
# crop cap
|
207 |
+
s = np.random.choice(self.zooms)
|
208 |
+
nn_zoom_cap = self.get_zoomed_cap(nn_cap, seed_corr[:2], s, 0)
|
209 |
+
query_zoom_cap = self.get_zoomed_cap(query_cap, seed_corr[2:], s, self.zoom_jitter)
|
210 |
+
assert nn_zoom_cap.shape == query_zoom_cap.shape == (constants.MAX_SIZE, constants.MAX_SIZE)
|
211 |
+
corrs = self.get_corrs(query_zoom_cap, nn_zoom_cap)
|
212 |
+
if corrs is None or corrs.shape[0] < self.num_kp:
|
213 |
+
return self.__getitem__(random.randint(0, self.__len__() - 1))
|
214 |
+
shuffle = np.random.permutation(corrs.shape[0])
|
215 |
+
corrs = np.take(corrs, shuffle, axis=0)
|
216 |
+
corrs = self._trim_corrs(corrs)
|
217 |
+
|
218 |
+
# flip augmentation
|
219 |
+
if np.random.uniform() < 0.5:
|
220 |
+
corrs[:, 0] = constants.MAX_SIZE - 1 - corrs[:, 0]
|
221 |
+
corrs[:, 2] = constants.MAX_SIZE - 1 - corrs[:, 2]
|
222 |
+
sbs_img = two_images_side_by_side(np.fliplr(query_zoom_cap.image), np.fliplr(nn_zoom_cap.image))
|
223 |
+
else:
|
224 |
+
sbs_img = two_images_side_by_side(query_zoom_cap.image, nn_zoom_cap.image)
|
225 |
+
|
226 |
+
corrs[:, 2] += constants.MAX_SIZE
|
227 |
+
corrs /= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
|
228 |
+
assert (0.0 <= corrs[:, 0]).all() and (corrs[:, 0] <= 0.5).all()
|
229 |
+
assert (0.0 <= corrs[:, 1]).all() and (corrs[:, 1] <= 1.0).all()
|
230 |
+
assert (0.5 <= corrs[:, 2]).all() and (corrs[:, 2] <= 1.0).all()
|
231 |
+
assert (0.0 <= corrs[:, 3]).all() and (corrs[:, 3] <= 1.0).all()
|
232 |
+
out = {
|
233 |
+
'image': tvtf.normalize(tvtf.to_tensor(sbs_img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
234 |
+
'corrs': torch.from_numpy(corrs).float(),
|
235 |
+
}
|
236 |
+
if self.bidirectional:
|
237 |
+
out['queries'] = torch.from_numpy(np.concatenate([corrs[:, :2], corrs[:, 2:]], axis=0)).float()
|
238 |
+
out['targets'] = torch.from_numpy(np.concatenate([corrs[:, 2:], corrs[:, :2]], axis=0)).float()
|
239 |
+
else:
|
240 |
+
out['queries'] = torch.from_numpy(corrs[:, :2]).float()
|
241 |
+
out['targets'] = torch.from_numpy(corrs[:, 2:]).float()
|
242 |
+
|
243 |
+
return out
|
third_party/COTR/COTR/datasets/megadepth_dataset.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
dataset specific layer for megadepth
|
3 |
+
'''
|
4 |
+
|
5 |
+
import os
|
6 |
+
import json
|
7 |
+
import random
|
8 |
+
from collections import namedtuple
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
from COTR.datasets import colmap_helper
|
13 |
+
from COTR.global_configs import dataset_config
|
14 |
+
from COTR.sfm_scenes import knn_search
|
15 |
+
from COTR.utils import debug_utils, utils, constants
|
16 |
+
|
17 |
+
SceneCapIndex = namedtuple('SceneCapIndex', ['scene_index', 'capture_index'])
|
18 |
+
|
19 |
+
|
20 |
+
def prefix_of_img_path_for_magedepth(img_path):
|
21 |
+
'''
|
22 |
+
get the prefix for image of megadepth dataset
|
23 |
+
'''
|
24 |
+
prefix = os.path.abspath(os.path.join(img_path, '../../../..')) + '/'
|
25 |
+
return prefix
|
26 |
+
|
27 |
+
|
28 |
+
class MegadepthSceneDataBase():
|
29 |
+
scenes = {}
|
30 |
+
knn_engine_dict = {}
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def _load_scene(cls, opt, scene_dir_dict):
|
34 |
+
if scene_dir_dict['scene_dir'] not in cls.scenes:
|
35 |
+
if opt.info_level == 'rgb':
|
36 |
+
assert 0
|
37 |
+
elif opt.info_level == 'rgbd':
|
38 |
+
scene_dir = scene_dir_dict['scene_dir']
|
39 |
+
images_dir = scene_dir_dict['image_dir']
|
40 |
+
depth_dir = scene_dir_dict['depth_dir']
|
41 |
+
scene = colmap_helper.ColmapWithDepthAsciiReader.read_sfm_scene_given_valid_list_path(scene_dir, images_dir, depth_dir, dataset_config[opt.dataset_name]['valid_list_json'], opt.crop_cam)
|
42 |
+
if opt.use_ram:
|
43 |
+
scene.read_data_to_ram(['image', 'depth'])
|
44 |
+
else:
|
45 |
+
raise ValueError()
|
46 |
+
knn_engine = knn_search.ReprojRatioKnnSearch(scene)
|
47 |
+
cls.scenes[scene_dir_dict['scene_dir']] = scene
|
48 |
+
cls.knn_engine_dict[scene_dir_dict['scene_dir']] = knn_engine
|
49 |
+
else:
|
50 |
+
pass
|
51 |
+
|
52 |
+
|
53 |
+
class MegadepthDataset():
|
54 |
+
|
55 |
+
def __init__(self, opt, dataset_type):
|
56 |
+
assert dataset_type in ['train', 'val', 'test']
|
57 |
+
assert len(opt.scenes_name_list) > 0
|
58 |
+
self.opt = opt
|
59 |
+
self.dataset_type = dataset_type
|
60 |
+
self.use_ram = opt.use_ram
|
61 |
+
self.scenes_name_list = opt.scenes_name_list
|
62 |
+
self.scenes = None
|
63 |
+
self.knn_engine_list = None
|
64 |
+
self.total_caps_set = None
|
65 |
+
self.query_caps_set = None
|
66 |
+
self.db_caps_set = None
|
67 |
+
self.img_path_to_scene_cap_index_dict = {}
|
68 |
+
self.scene_index_to_db_caps_mask_dict = {}
|
69 |
+
self._load_scenes()
|
70 |
+
|
71 |
+
@property
|
72 |
+
def num_scenes(self):
|
73 |
+
return len(self.scenes)
|
74 |
+
|
75 |
+
@property
|
76 |
+
def num_queries(self):
|
77 |
+
return len(self.query_caps_set)
|
78 |
+
|
79 |
+
@property
|
80 |
+
def num_db(self):
|
81 |
+
return len(self.db_caps_set)
|
82 |
+
|
83 |
+
def get_scene_cap_index_by_index(self, index):
|
84 |
+
assert index < len(self.query_caps_set)
|
85 |
+
img_path = sorted(list(self.query_caps_set))[index]
|
86 |
+
scene_cap_index = self.img_path_to_scene_cap_index_dict[img_path]
|
87 |
+
return scene_cap_index
|
88 |
+
|
89 |
+
def _get_common_subset_caps_from_json(self, json_path, total_caps):
|
90 |
+
prefix = prefix_of_img_path_for_magedepth(list(total_caps)[0])
|
91 |
+
with open(json_path, 'r') as f:
|
92 |
+
common_caps = [prefix + cap for cap in json.load(f)]
|
93 |
+
common_caps = set(total_caps) & set(common_caps)
|
94 |
+
return common_caps
|
95 |
+
|
96 |
+
def _extend_img_path_to_scene_cap_index_dict(self, img_path_to_cap_index_dict, scene_id):
|
97 |
+
for key in img_path_to_cap_index_dict.keys():
|
98 |
+
self.img_path_to_scene_cap_index_dict[key] = SceneCapIndex(scene_id, img_path_to_cap_index_dict[key])
|
99 |
+
|
100 |
+
def _create_scene_index_to_db_caps_mask_dict(self, db_caps_set):
|
101 |
+
scene_index_to_db_caps_mask_dict = {}
|
102 |
+
for cap in db_caps_set:
|
103 |
+
scene_id, cap_id = self.img_path_to_scene_cap_index_dict[cap]
|
104 |
+
if scene_id not in scene_index_to_db_caps_mask_dict:
|
105 |
+
scene_index_to_db_caps_mask_dict[scene_id] = []
|
106 |
+
scene_index_to_db_caps_mask_dict[scene_id].append(cap_id)
|
107 |
+
for _k, _v in scene_index_to_db_caps_mask_dict.items():
|
108 |
+
scene_index_to_db_caps_mask_dict[_k] = np.array(sorted(_v))
|
109 |
+
return scene_index_to_db_caps_mask_dict
|
110 |
+
|
111 |
+
def _load_scenes(self):
|
112 |
+
scenes = []
|
113 |
+
knn_engine_list = []
|
114 |
+
total_caps_set = set()
|
115 |
+
for scene_id, scene_dir_dict in enumerate(self.scenes_name_list):
|
116 |
+
MegadepthSceneDataBase._load_scene(self.opt, scene_dir_dict)
|
117 |
+
scene = MegadepthSceneDataBase.scenes[scene_dir_dict['scene_dir']]
|
118 |
+
knn_engine = MegadepthSceneDataBase.knn_engine_dict[scene_dir_dict['scene_dir']]
|
119 |
+
total_caps_set = total_caps_set | set(scene.img_path_to_index_dict.keys())
|
120 |
+
self._extend_img_path_to_scene_cap_index_dict(scene.img_path_to_index_dict, scene_id)
|
121 |
+
scenes.append(scene)
|
122 |
+
knn_engine_list.append(knn_engine)
|
123 |
+
self.scenes = scenes
|
124 |
+
self.knn_engine_list = knn_engine_list
|
125 |
+
self.total_caps_set = total_caps_set
|
126 |
+
self.query_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name][f'{self.dataset_type}_json'], total_caps_set)
|
127 |
+
self.db_caps_set = self._get_common_subset_caps_from_json(dataset_config[self.opt.dataset_name]['train_json'], total_caps_set)
|
128 |
+
self.scene_index_to_db_caps_mask_dict = self._create_scene_index_to_db_caps_mask_dict(self.db_caps_set)
|
129 |
+
|
130 |
+
def get_query_with_knn(self, index):
|
131 |
+
scene_index, cap_index = self.get_scene_cap_index_by_index(index)
|
132 |
+
query_cap = self.scenes[scene_index].captures[cap_index]
|
133 |
+
knn_engine = self.knn_engine_list[scene_index]
|
134 |
+
if scene_index in self.scene_index_to_db_caps_mask_dict:
|
135 |
+
db_mask = self.scene_index_to_db_caps_mask_dict[scene_index]
|
136 |
+
else:
|
137 |
+
db_mask = None
|
138 |
+
pool = knn_engine.get_knn(query_cap, self.opt.pool_size, db_mask=db_mask)
|
139 |
+
nn_caps = random.sample(pool, min(len(pool), self.opt.k_size))
|
140 |
+
return query_cap, nn_caps
|
third_party/COTR/COTR/global_configs/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
|
4 |
+
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
|
5 |
+
with open(os.path.join(__location__, 'dataset_config.json'), 'r') as f:
|
6 |
+
dataset_config = json.load(f)
|
7 |
+
with open(os.path.join(__location__, 'commons.json'), 'r') as f:
|
8 |
+
general_config = json.load(f)
|
9 |
+
# assert os.path.isdir(general_config['out']), f'Please create {general_config["out"]}'
|
10 |
+
# assert os.path.isdir(general_config['tb_out']), f'Please create {general_config["tb_out"]}'
|
third_party/COTR/COTR/global_configs/commons.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"out": "../../out", "tb_out": "../../tb_out"}
|
third_party/COTR/COTR/global_configs/dataset_config.json
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"megadepth": {
|
3 |
+
"valid_list_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_valid_list.json",
|
4 |
+
"train_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_train.json",
|
5 |
+
"val_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_val.json",
|
6 |
+
"test_json": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/megadepth_test.json",
|
7 |
+
"scene_dir": "/media/jiangwei/data_ssd/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
|
8 |
+
"image_dir": "/media/jiangwei/data_ssd/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
|
9 |
+
"depth_dir": "/media/jiangwei/data_ssd/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
|
10 |
+
},
|
11 |
+
|
12 |
+
"megadepth_sushi": {
|
13 |
+
"valid_list_json": "/scratch/dataset/megadepth/MegaDepth_v1_SfM/megadepth_valid_list.json",
|
14 |
+
"train_json": "/scratch/programs/COTR/sample_data/megadepth_train.json",
|
15 |
+
"val_json": "/scratch/programs/COTR/sample_data/megadepth_val.json",
|
16 |
+
"test_json": "/scratch/dataset/megadepth/MegaDepth_v1_SfM/megadepth_test.json",
|
17 |
+
"scene_dir": "/scratch/dataset/megadepth/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
|
18 |
+
"image_dir": "/scratch/dataset/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
|
19 |
+
"depth_dir": "/scratch/dataset/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
|
20 |
+
},
|
21 |
+
|
22 |
+
"megadepth_sockeye": {
|
23 |
+
"valid_list_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_valid_list.json",
|
24 |
+
"train_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_train.json",
|
25 |
+
"val_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_val.json",
|
26 |
+
"test_json": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/megadepth_test.json",
|
27 |
+
"scene_dir": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
|
28 |
+
"image_dir": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
|
29 |
+
"depth_dir": "/project/pr-kmyi-1/jiangwei/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
|
30 |
+
},
|
31 |
+
|
32 |
+
"megadepth_snubfin": {
|
33 |
+
"valid_list_json": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1_SfM/megadepth_valid_list.json",
|
34 |
+
"train_json": "/ubc/cs/research/kmyi/jw221/programs/COTR/sample_data/megadepth_train.json",
|
35 |
+
"val_json": "/ubc/cs/research/kmyi/jw221/programs/COTR/sample_data/megadepth_val.json",
|
36 |
+
"test_json": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1_SfM/megadepth_test.json",
|
37 |
+
"scene_dir": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1_SfM/{0}/sparse/manhattan/{1}_rectified/sparse",
|
38 |
+
"image_dir": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/imgs",
|
39 |
+
"depth_dir": "/ubc/cs/research/kmyi/datasets/megadepth/MegaDepth_v1/phoenix/S6/zl548/MegaDepth_v1/{0}/dense{1}/depths"
|
40 |
+
}
|
41 |
+
}
|
third_party/COTR/COTR/inference/inference_helper.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torchvision.transforms import functional as tvtf
|
7 |
+
from tqdm import tqdm
|
8 |
+
import PIL
|
9 |
+
|
10 |
+
from COTR.utils import utils, debug_utils
|
11 |
+
from COTR.utils.constants import MAX_SIZE
|
12 |
+
from COTR.cameras.capture import crop_center_max_np, pad_to_square_np
|
13 |
+
from COTR.utils.utils import ImagePatch
|
14 |
+
|
15 |
+
THRESHOLD_SPARSE = 0.02
|
16 |
+
THRESHOLD_PIXELS_RELATIVE = 0.02
|
17 |
+
BASE_ZOOM = 1.0
|
18 |
+
THRESHOLD_AREA = 0.02
|
19 |
+
LARGE_GPU = True
|
20 |
+
|
21 |
+
|
22 |
+
def find_prediction_loop(arr):
|
23 |
+
'''
|
24 |
+
loop ends at last element
|
25 |
+
'''
|
26 |
+
assert arr.shape[1] == 2, 'requires shape (N, 2)'
|
27 |
+
start_index = np.where(np.prod(arr[:-1] == arr[-1], axis=1))[0][0]
|
28 |
+
return arr[start_index:-1]
|
29 |
+
|
30 |
+
|
31 |
+
def two_images_side_by_side(img_a, img_b):
|
32 |
+
assert img_a.shape == img_b.shape, f'{img_a.shape} vs {img_b.shape}'
|
33 |
+
assert img_a.dtype == img_b.dtype
|
34 |
+
h, w, c = img_a.shape
|
35 |
+
canvas = np.zeros((h, 2 * w, c), dtype=img_a.dtype)
|
36 |
+
canvas[:, 0 * w:1 * w, :] = img_a
|
37 |
+
canvas[:, 1 * w:2 * w, :] = img_b
|
38 |
+
return canvas
|
39 |
+
|
40 |
+
|
41 |
+
def to_square_patches(img):
|
42 |
+
patches = []
|
43 |
+
h, w, _ = img.shape
|
44 |
+
short = size = min(h, w)
|
45 |
+
long = max(h, w)
|
46 |
+
if long == short:
|
47 |
+
patch_0 = ImagePatch(img[:size, :size], 0, 0, size, size, w, h)
|
48 |
+
patches = [patch_0]
|
49 |
+
elif long <= size * 2:
|
50 |
+
warnings.warn('Spatial smoothness in dense optical flow is lost, but sparse matching and triangulation should be fine')
|
51 |
+
patch_0 = ImagePatch(img[:size, :size], 0, 0, size, size, w, h)
|
52 |
+
patch_1 = ImagePatch(img[-size:, -size:], w - size, h - size, size, size, w, h)
|
53 |
+
patches = [patch_0, patch_1]
|
54 |
+
# patches += subdivide_patch(patch_0)
|
55 |
+
# patches += subdivide_patch(patch_1)
|
56 |
+
else:
|
57 |
+
raise NotImplementedError
|
58 |
+
return patches
|
59 |
+
|
60 |
+
|
61 |
+
def merge_flow_patches(corrs):
|
62 |
+
confidence = np.ones([corrs[0].oh, corrs[0].ow]) * 100
|
63 |
+
flow = np.zeros([corrs[0].oh, corrs[0].ow, 2])
|
64 |
+
cmap = np.ones([corrs[0].oh, corrs[0].ow]) * -1
|
65 |
+
for i, c in enumerate(corrs):
|
66 |
+
temp = np.ones([c.oh, c.ow]) * 100
|
67 |
+
temp[c.y:c.y + c.h, c.x:c.x + c.w] = c.patch[..., 2]
|
68 |
+
tempf = np.zeros([c.oh, c.ow, 2])
|
69 |
+
tempf[c.y:c.y + c.h, c.x:c.x + c.w] = c.patch[..., :2]
|
70 |
+
min_ind = np.stack([temp, confidence], axis=-1).argmin(axis=-1)
|
71 |
+
min_ind = min_ind == 0
|
72 |
+
confidence[min_ind] = temp[min_ind]
|
73 |
+
flow[min_ind] = tempf[min_ind]
|
74 |
+
cmap[min_ind] = i
|
75 |
+
return flow, confidence, cmap
|
76 |
+
|
77 |
+
|
78 |
+
def get_patch_centered_at(img, pos, scale=1.0, return_content=True, img_shape=None):
|
79 |
+
'''
|
80 |
+
pos - [x, y]
|
81 |
+
'''
|
82 |
+
if img_shape is None:
|
83 |
+
img_shape = img.shape
|
84 |
+
h, w, _ = img_shape
|
85 |
+
short = min(h, w)
|
86 |
+
scale = np.clip(scale, 0.0, 1.0)
|
87 |
+
size = short * scale
|
88 |
+
size = int((size // 2) * 2)
|
89 |
+
lu_y = int(pos[1] - size // 2)
|
90 |
+
lu_x = int(pos[0] - size // 2)
|
91 |
+
if lu_y < 0:
|
92 |
+
lu_y -= lu_y
|
93 |
+
if lu_x < 0:
|
94 |
+
lu_x -= lu_x
|
95 |
+
if lu_y + size > h:
|
96 |
+
lu_y -= (lu_y + size) - (h)
|
97 |
+
if lu_x + size > w:
|
98 |
+
lu_x -= (lu_x + size) - (w)
|
99 |
+
if return_content:
|
100 |
+
return ImagePatch(img[lu_y:lu_y + size, lu_x:lu_x + size], lu_x, lu_y, size, size, w, h)
|
101 |
+
else:
|
102 |
+
return ImagePatch(None, lu_x, lu_y, size, size, w, h)
|
103 |
+
|
104 |
+
|
105 |
+
def cotr_patch_flow_exhaustive(model, patches_a, patches_b):
|
106 |
+
def one_pass(model, img_a, img_b):
|
107 |
+
device = next(model.parameters()).device
|
108 |
+
assert img_a.shape[0] == img_a.shape[1]
|
109 |
+
assert img_b.shape[0] == img_b.shape[1]
|
110 |
+
img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
|
111 |
+
img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
|
112 |
+
img = two_images_side_by_side(img_a, img_b)
|
113 |
+
img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None]
|
114 |
+
img = img.to(device)
|
115 |
+
|
116 |
+
q_list = []
|
117 |
+
for i in range(MAX_SIZE):
|
118 |
+
queries = []
|
119 |
+
for j in range(MAX_SIZE * 2):
|
120 |
+
queries.append([(j) / (MAX_SIZE * 2), i / MAX_SIZE])
|
121 |
+
queries = np.array(queries)
|
122 |
+
q_list.append(queries)
|
123 |
+
if LARGE_GPU:
|
124 |
+
try:
|
125 |
+
queries = torch.from_numpy(np.concatenate(q_list))[None].float().to(device)
|
126 |
+
out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0]
|
127 |
+
out_list = out.reshape(MAX_SIZE, MAX_SIZE * 2, -1)
|
128 |
+
except:
|
129 |
+
assert 0, 'set LARGE_GPU to False'
|
130 |
+
else:
|
131 |
+
out_list = []
|
132 |
+
for q in q_list:
|
133 |
+
queries = torch.from_numpy(q)[None].float().to(device)
|
134 |
+
out = model.forward(img, queries)['pred_corrs'].detach().cpu().numpy()[0]
|
135 |
+
out_list.append(out)
|
136 |
+
out_list = np.array(out_list)
|
137 |
+
in_grid = torch.from_numpy(np.array(q_list)).float()[None] * 2 - 1
|
138 |
+
out_grid = torch.from_numpy(out_list).float()[None] * 2 - 1
|
139 |
+
cycle_grid = torch.nn.functional.grid_sample(out_grid.permute(0, 3, 1, 2), out_grid).permute(0, 2, 3, 1)
|
140 |
+
confidence = torch.norm(cycle_grid[0, ...] - in_grid[0, ...], dim=-1)
|
141 |
+
corr = out_grid[0].clone()
|
142 |
+
corr[:, :MAX_SIZE, 0] = corr[:, :MAX_SIZE, 0] * 2 - 1
|
143 |
+
corr[:, MAX_SIZE:, 0] = corr[:, MAX_SIZE:, 0] * 2 + 1
|
144 |
+
corr = torch.cat([corr, confidence[..., None]], dim=-1).numpy()
|
145 |
+
return corr[:, :MAX_SIZE, :], corr[:, MAX_SIZE:, :]
|
146 |
+
corrs_a = []
|
147 |
+
corrs_b = []
|
148 |
+
|
149 |
+
for p_i in patches_a:
|
150 |
+
for p_j in patches_b:
|
151 |
+
c_i, c_j = one_pass(model, p_i.patch, p_j.patch)
|
152 |
+
base_corners = np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]])
|
153 |
+
real_corners_j = (np.array([[p_j.x, p_j.y], [p_j.x + p_j.w, p_j.y], [p_j.x + p_j.w, p_j.y + p_j.h], [p_j.x, p_j.y + p_j.h]]) / np.array([p_j.ow, p_j.oh])) * 2 + np.array([-1, -1])
|
154 |
+
real_corners_i = (np.array([[p_i.x, p_i.y], [p_i.x + p_i.w, p_i.y], [p_i.x + p_i.w, p_i.y + p_i.h], [p_i.x, p_i.y + p_i.h]]) / np.array([p_i.ow, p_i.oh])) * 2 + np.array([-1, -1])
|
155 |
+
T_i = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_j[:3].astype(np.float32))
|
156 |
+
T_j = cv2.getAffineTransform(base_corners[:3].astype(np.float32), real_corners_i[:3].astype(np.float32))
|
157 |
+
c_i[..., :2] = c_i[..., :2] @ T_i[:2, :2] + T_i[:, 2]
|
158 |
+
c_j[..., :2] = c_j[..., :2] @ T_j[:2, :2] + T_j[:, 2]
|
159 |
+
c_i = utils.float_image_resize(c_i, (p_i.h, p_i.w))
|
160 |
+
c_j = utils.float_image_resize(c_j, (p_j.h, p_j.w))
|
161 |
+
c_i = ImagePatch(c_i, p_i.x, p_i.y, p_i.w, p_i.h, p_i.ow, p_i.oh)
|
162 |
+
c_j = ImagePatch(c_j, p_j.x, p_j.y, p_j.w, p_j.h, p_j.ow, p_j.oh)
|
163 |
+
corrs_a.append(c_i)
|
164 |
+
corrs_b.append(c_j)
|
165 |
+
return corrs_a, corrs_b
|
166 |
+
|
167 |
+
|
168 |
+
def cotr_flow(model, img_a, img_b):
|
169 |
+
# assert img_a.shape[0] == img_a.shape[1]
|
170 |
+
# assert img_b.shape[0] == img_b.shape[1]
|
171 |
+
patches_a = to_square_patches(img_a)
|
172 |
+
patches_b = to_square_patches(img_b)
|
173 |
+
|
174 |
+
corrs_a, corrs_b = cotr_patch_flow_exhaustive(model, patches_a, patches_b)
|
175 |
+
corr_a, con_a, cmap_a = merge_flow_patches(corrs_a)
|
176 |
+
corr_b, con_b, cmap_b = merge_flow_patches(corrs_b)
|
177 |
+
|
178 |
+
resample_a = utils.torch_img_to_np_img(torch.nn.functional.grid_sample(utils.np_img_to_torch_img(img_b)[None].float(),
|
179 |
+
torch.from_numpy(corr_a)[None].float())[0])
|
180 |
+
resample_b = utils.torch_img_to_np_img(torch.nn.functional.grid_sample(utils.np_img_to_torch_img(img_a)[None].float(),
|
181 |
+
torch.from_numpy(corr_b)[None].float())[0])
|
182 |
+
return corr_a, con_a, resample_a, corr_b, con_b, resample_b
|
183 |
+
|
184 |
+
|
185 |
+
def cotr_corr_base(model, img_a, img_b, queries_a):
|
186 |
+
def one_pass(model, img_a, img_b, queries):
|
187 |
+
device = next(model.parameters()).device
|
188 |
+
assert img_a.shape[0] == img_a.shape[1]
|
189 |
+
assert img_b.shape[0] == img_b.shape[1]
|
190 |
+
img_a = np.array(PIL.Image.fromarray(img_a).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
|
191 |
+
img_b = np.array(PIL.Image.fromarray(img_b).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
|
192 |
+
img = two_images_side_by_side(img_a, img_b)
|
193 |
+
img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()[None]
|
194 |
+
img = img.to(device)
|
195 |
+
|
196 |
+
queries = torch.from_numpy(queries)[None].float().to(device)
|
197 |
+
out = model.forward(img, queries)['pred_corrs'].clone().detach()
|
198 |
+
cycle = model.forward(img, out)['pred_corrs'].clone().detach()
|
199 |
+
|
200 |
+
queries = queries.cpu().numpy()[0]
|
201 |
+
out = out.cpu().numpy()[0]
|
202 |
+
cycle = cycle.cpu().numpy()[0]
|
203 |
+
conf = np.linalg.norm(queries - cycle, axis=1, keepdims=True)
|
204 |
+
return np.concatenate([out, conf], axis=1)
|
205 |
+
|
206 |
+
patches_a = to_square_patches(img_a)
|
207 |
+
patches_b = to_square_patches(img_b)
|
208 |
+
pred_list = []
|
209 |
+
|
210 |
+
for p_i in patches_a:
|
211 |
+
for p_j in patches_b:
|
212 |
+
normalized_queries_a = queries_a.copy()
|
213 |
+
mask = (normalized_queries_a[:, 0] >= p_i.x) & (normalized_queries_a[:, 1] >= p_i.y) & (normalized_queries_a[:, 0] <= p_i.x + p_i.w) & (normalized_queries_a[:, 1] <= p_i.y + p_i.h)
|
214 |
+
normalized_queries_a[:, 0] -= p_i.x
|
215 |
+
normalized_queries_a[:, 1] -= p_i.y
|
216 |
+
normalized_queries_a[:, 0] /= 2 * p_i.w
|
217 |
+
normalized_queries_a[:, 1] /= p_i.h
|
218 |
+
pred = one_pass(model, p_i.patch, p_j.patch, normalized_queries_a)
|
219 |
+
pred[~mask, 2] = np.inf
|
220 |
+
pred[:, 0] -= 0.5
|
221 |
+
pred[:, 0] *= 2 * p_j.w
|
222 |
+
pred[:, 0] += p_j.x
|
223 |
+
pred[:, 1] *= p_j.h
|
224 |
+
pred[:, 1] += p_j.y
|
225 |
+
pred_list.append(pred)
|
226 |
+
|
227 |
+
pred_list = np.stack(pred_list).transpose(1, 0, 2)
|
228 |
+
out = []
|
229 |
+
for item in pred_list:
|
230 |
+
out.append(item[np.argmin(item[..., 2], axis=0)])
|
231 |
+
out = np.array(out)[..., :2]
|
232 |
+
return np.concatenate([queries_a, out], axis=1)
|
233 |
+
|
234 |
+
|
235 |
+
try:
|
236 |
+
from vispy import gloo
|
237 |
+
from vispy import app
|
238 |
+
from vispy.util.ptime import time
|
239 |
+
from scipy.spatial import Delaunay
|
240 |
+
from vispy.gloo.wrappers import read_pixels
|
241 |
+
|
242 |
+
app.use_app('glfw')
|
243 |
+
|
244 |
+
|
245 |
+
vertex_shader = """
|
246 |
+
attribute vec4 color;
|
247 |
+
attribute vec2 position;
|
248 |
+
varying vec4 v_color;
|
249 |
+
void main()
|
250 |
+
{
|
251 |
+
gl_Position = vec4(position, 0.0, 1.0);
|
252 |
+
v_color = color;
|
253 |
+
} """
|
254 |
+
|
255 |
+
fragment_shader = """
|
256 |
+
varying vec4 v_color;
|
257 |
+
void main()
|
258 |
+
{
|
259 |
+
gl_FragColor = v_color;
|
260 |
+
} """
|
261 |
+
|
262 |
+
|
263 |
+
class Canvas(app.Canvas):
|
264 |
+
def __init__(self, mesh, color, size):
|
265 |
+
# We hide the canvas upon creation.
|
266 |
+
app.Canvas.__init__(self, show=False, size=size)
|
267 |
+
self._t0 = time()
|
268 |
+
# Texture where we render the scene.
|
269 |
+
self._rendertex = gloo.Texture2D(shape=self.size[::-1] + (4,), internalformat='rgba32f')
|
270 |
+
# FBO.
|
271 |
+
self._fbo = gloo.FrameBuffer(self._rendertex,
|
272 |
+
gloo.RenderBuffer(self.size[::-1]))
|
273 |
+
# Regular program that will be rendered to the FBO.
|
274 |
+
self.program = gloo.Program(vertex_shader, fragment_shader)
|
275 |
+
self.program["position"] = mesh
|
276 |
+
self.program['color'] = color
|
277 |
+
# We manually draw the hidden canvas.
|
278 |
+
self.update()
|
279 |
+
|
280 |
+
def on_draw(self, event):
|
281 |
+
# Render in the FBO.
|
282 |
+
with self._fbo:
|
283 |
+
gloo.clear('black')
|
284 |
+
gloo.set_viewport(0, 0, *self.size)
|
285 |
+
self.program.draw()
|
286 |
+
# Retrieve the contents of the FBO texture.
|
287 |
+
self.im = read_pixels((0, 0, self.size[0], self.size[1]), True, out_type='float')
|
288 |
+
self._time = time() - self._t0
|
289 |
+
# Immediately exit the application.
|
290 |
+
app.quit()
|
291 |
+
|
292 |
+
|
293 |
+
def triangulate_corr(corr, from_shape, to_shape):
|
294 |
+
corr = corr.copy()
|
295 |
+
to_shape = to_shape[:2]
|
296 |
+
from_shape = from_shape[:2]
|
297 |
+
corr = corr / np.concatenate([from_shape[::-1], to_shape[::-1]])
|
298 |
+
tri = Delaunay(corr[:, :2])
|
299 |
+
mesh = corr[:, :2][tri.simplices].astype(np.float32) * 2 - 1
|
300 |
+
mesh[..., 1] *= -1
|
301 |
+
color = corr[:, 2:][tri.simplices].astype(np.float32)
|
302 |
+
color = np.concatenate([color, np.ones_like(color[..., 0:2])], axis=-1)
|
303 |
+
c = Canvas(mesh.reshape(-1, 2), color.reshape(-1, 4), size=(from_shape[::-1]))
|
304 |
+
app.run()
|
305 |
+
render = c.im.copy()
|
306 |
+
render = render[..., :2]
|
307 |
+
render *= np.array(to_shape[::-1])
|
308 |
+
return render
|
309 |
+
except:
|
310 |
+
print('cannot use vispy, setting triangulate_corr as None')
|
311 |
+
triangulate_corr = None
|
third_party/COTR/COTR/inference/refinement_task.py
ADDED
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torchvision.transforms import functional as tvtf
|
6 |
+
import imageio
|
7 |
+
import PIL
|
8 |
+
|
9 |
+
from COTR.inference.inference_helper import BASE_ZOOM, THRESHOLD_PIXELS_RELATIVE, get_patch_centered_at, two_images_side_by_side, find_prediction_loop
|
10 |
+
from COTR.utils import debug_utils, utils
|
11 |
+
from COTR.utils.constants import MAX_SIZE
|
12 |
+
from COTR.utils.utils import ImagePatch
|
13 |
+
|
14 |
+
|
15 |
+
class RefinementTask():
|
16 |
+
def __init__(self, image_from, image_to, loc_from, loc_to, area_from, area_to, converge_iters, zoom_ins, identifier=None):
|
17 |
+
self.identifier = identifier
|
18 |
+
self.image_from = image_from
|
19 |
+
self.image_to = image_to
|
20 |
+
self.loc_from = loc_from
|
21 |
+
self.best_loc_to = loc_to
|
22 |
+
self.cur_loc_to = loc_to
|
23 |
+
self.area_from = area_from
|
24 |
+
self.area_to = area_to
|
25 |
+
if self.area_from < self.area_to:
|
26 |
+
self.s_from = BASE_ZOOM
|
27 |
+
self.s_to = BASE_ZOOM * np.sqrt(self.area_to / self.area_from)
|
28 |
+
else:
|
29 |
+
self.s_to = BASE_ZOOM
|
30 |
+
self.s_from = BASE_ZOOM * np.sqrt(self.area_from / self.area_to)
|
31 |
+
|
32 |
+
self.cur_job = {}
|
33 |
+
self.status = 'unfinished'
|
34 |
+
self.result = 'unknown'
|
35 |
+
|
36 |
+
self.converge_iters = converge_iters
|
37 |
+
self.zoom_ins = zoom_ins
|
38 |
+
self.cur_zoom_idx = 0
|
39 |
+
self.cur_iter = 0
|
40 |
+
self.total_iter = 0
|
41 |
+
|
42 |
+
self.loc_to_at_zoom = []
|
43 |
+
self.loc_history = [loc_to]
|
44 |
+
self.all_loc_to_dict = {}
|
45 |
+
self.job_history = []
|
46 |
+
self.submitted = False
|
47 |
+
|
48 |
+
@property
|
49 |
+
def cur_zoom(self):
|
50 |
+
return self.zoom_ins[self.cur_zoom_idx]
|
51 |
+
|
52 |
+
@property
|
53 |
+
def confidence_scaling_factor(self):
|
54 |
+
if self.cur_zoom_idx > 0:
|
55 |
+
conf_scaling = float(self.cur_zoom) / float(self.zoom_ins[0])
|
56 |
+
else:
|
57 |
+
conf_scaling = 1.0
|
58 |
+
return conf_scaling
|
59 |
+
|
60 |
+
def peek(self):
|
61 |
+
assert self.status == 'unfinished'
|
62 |
+
patch_from = get_patch_centered_at(None, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False, img_shape=self.image_from.shape)
|
63 |
+
patch_to = get_patch_centered_at(None, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False, img_shape=self.image_to.shape)
|
64 |
+
top_job = {'patch_from': patch_from,
|
65 |
+
'patch_to': patch_to,
|
66 |
+
'loc_from': self.loc_from,
|
67 |
+
'loc_to': self.cur_loc_to,
|
68 |
+
}
|
69 |
+
return top_job
|
70 |
+
|
71 |
+
def get_task_pilot(self, pilot):
|
72 |
+
assert self.status == 'unfinished'
|
73 |
+
patch_from = ImagePatch(None, pilot.cur_job['patch_from'].x, pilot.cur_job['patch_from'].y, pilot.cur_job['patch_from'].w, pilot.cur_job['patch_from'].h, pilot.cur_job['patch_from'].ow, pilot.cur_job['patch_from'].oh)
|
74 |
+
patch_to = ImagePatch(None, pilot.cur_job['patch_to'].x, pilot.cur_job['patch_to'].y, pilot.cur_job['patch_to'].w, pilot.cur_job['patch_to'].h, pilot.cur_job['patch_to'].ow, pilot.cur_job['patch_to'].oh)
|
75 |
+
query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float()
|
76 |
+
self.cur_job = {'patch_from': patch_from,
|
77 |
+
'patch_to': patch_to,
|
78 |
+
'loc_from': self.loc_from,
|
79 |
+
'loc_to': self.cur_loc_to,
|
80 |
+
'img': None,
|
81 |
+
}
|
82 |
+
self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w))
|
83 |
+
assert self.submitted == False
|
84 |
+
self.submitted = True
|
85 |
+
return None, query
|
86 |
+
|
87 |
+
def get_task_fast(self):
|
88 |
+
assert self.status == 'unfinished'
|
89 |
+
patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom, return_content=False)
|
90 |
+
patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom, return_content=False)
|
91 |
+
query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float()
|
92 |
+
self.cur_job = {'patch_from': patch_from,
|
93 |
+
'patch_to': patch_to,
|
94 |
+
'loc_from': self.loc_from,
|
95 |
+
'loc_to': self.cur_loc_to,
|
96 |
+
'img': None,
|
97 |
+
}
|
98 |
+
|
99 |
+
self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w))
|
100 |
+
assert self.submitted == False
|
101 |
+
self.submitted = True
|
102 |
+
|
103 |
+
return None, query
|
104 |
+
|
105 |
+
def get_task(self):
|
106 |
+
assert self.status == 'unfinished'
|
107 |
+
patch_from = get_patch_centered_at(self.image_from, self.loc_from, scale=self.s_from * self.cur_zoom)
|
108 |
+
patch_to = get_patch_centered_at(self.image_to, self.cur_loc_to, scale=self.s_to * self.cur_zoom)
|
109 |
+
|
110 |
+
query = torch.from_numpy((np.array(self.loc_from) - np.array([patch_from.x, patch_from.y])) / np.array([patch_from.w * 2, patch_from.h]))[None].float()
|
111 |
+
|
112 |
+
img_from = patch_from.patch
|
113 |
+
img_to = patch_to.patch
|
114 |
+
assert img_from.shape[0] == img_from.shape[1]
|
115 |
+
assert img_to.shape[0] == img_to.shape[1]
|
116 |
+
|
117 |
+
img_from = np.array(PIL.Image.fromarray(img_from).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
|
118 |
+
img_to = np.array(PIL.Image.fromarray(img_to).resize((MAX_SIZE, MAX_SIZE), resample=PIL.Image.BILINEAR))
|
119 |
+
img = two_images_side_by_side(img_from, img_to)
|
120 |
+
img = tvtf.normalize(tvtf.to_tensor(img), (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)).float()
|
121 |
+
|
122 |
+
self.cur_job = {'patch_from': ImagePatch(None, patch_from.x, patch_from.y, patch_from.w, patch_from.h, patch_from.ow, patch_from.oh),
|
123 |
+
'patch_to': ImagePatch(None, patch_to.x, patch_to.y, patch_to.w, patch_to.h, patch_to.ow, patch_to.oh),
|
124 |
+
'loc_from': self.loc_from,
|
125 |
+
'loc_to': self.cur_loc_to,
|
126 |
+
}
|
127 |
+
|
128 |
+
self.job_history.append((patch_from.h, patch_from.w, patch_to.h, patch_to.w))
|
129 |
+
assert self.submitted == False
|
130 |
+
self.submitted = True
|
131 |
+
|
132 |
+
return img, query
|
133 |
+
|
134 |
+
def next_zoom(self):
|
135 |
+
if self.cur_zoom_idx >= len(self.zoom_ins) - 1:
|
136 |
+
self.status = 'finished'
|
137 |
+
if self.conclude() is None:
|
138 |
+
self.result = 'bad'
|
139 |
+
else:
|
140 |
+
self.result = 'good'
|
141 |
+
self.cur_zoom_idx += 1
|
142 |
+
self.cur_iter = 0
|
143 |
+
self.loc_to_at_zoom = []
|
144 |
+
|
145 |
+
def scale_to_loc(self, raw_to_loc):
|
146 |
+
raw_to_loc = raw_to_loc.copy()
|
147 |
+
patch_b = self.cur_job['patch_to']
|
148 |
+
raw_to_loc[0] = (raw_to_loc[0] - 0.5) * 2
|
149 |
+
loc_to = raw_to_loc * np.array([patch_b.w, patch_b.h])
|
150 |
+
loc_to = loc_to + np.array([patch_b.x, patch_b.y])
|
151 |
+
return loc_to
|
152 |
+
|
153 |
+
def step(self, raw_to_loc):
|
154 |
+
assert self.submitted == True
|
155 |
+
self.submitted = False
|
156 |
+
loc_to = self.scale_to_loc(raw_to_loc)
|
157 |
+
self.total_iter += 1
|
158 |
+
self.loc_to_at_zoom.append(loc_to)
|
159 |
+
self.cur_loc_to = loc_to
|
160 |
+
zoom_finished = False
|
161 |
+
if self.cur_zoom_idx == len(self.zoom_ins) - 1:
|
162 |
+
# converge at the last level
|
163 |
+
if len(self.loc_to_at_zoom) >= 2:
|
164 |
+
zoom_finished = np.prod(self.loc_to_at_zoom[:-1] == loc_to, axis=1, keepdims=True).any()
|
165 |
+
if self.cur_iter >= self.converge_iters - 1:
|
166 |
+
zoom_finished = True
|
167 |
+
self.cur_iter += 1
|
168 |
+
else:
|
169 |
+
# finish immediately for other levels
|
170 |
+
zoom_finished = True
|
171 |
+
if zoom_finished:
|
172 |
+
self.all_loc_to_dict[self.cur_zoom] = np.array(self.loc_to_at_zoom).copy()
|
173 |
+
last_level_loc_to = self.all_loc_to_dict[self.cur_zoom]
|
174 |
+
if len(last_level_loc_to) >= 2:
|
175 |
+
has_loop = np.prod(last_level_loc_to[:-1] == last_level_loc_to[-1], axis=1, keepdims=True).any()
|
176 |
+
if has_loop:
|
177 |
+
loop = find_prediction_loop(last_level_loc_to)
|
178 |
+
loc_to = loop.mean(axis=0)
|
179 |
+
self.loc_history.append(loc_to)
|
180 |
+
self.best_loc_to = loc_to
|
181 |
+
self.cur_loc_to = self.best_loc_to
|
182 |
+
self.next_zoom()
|
183 |
+
|
184 |
+
def conclude(self, force=False):
|
185 |
+
loc_history = np.array(self.loc_history)
|
186 |
+
if (force == False) and (max(loc_history.std(axis=0)) >= THRESHOLD_PIXELS_RELATIVE * max(*self.image_to.shape)):
|
187 |
+
return None
|
188 |
+
return np.concatenate([self.loc_from, self.best_loc_to])
|
189 |
+
|
190 |
+
def conclude_intermedia(self):
|
191 |
+
return np.concatenate([np.array(self.loc_history), np.array(self.job_history)], axis=1)
|
third_party/COTR/COTR/inference/sparse_engine.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Inference engine for sparse image pair correspondences
|
3 |
+
'''
|
4 |
+
|
5 |
+
import time
|
6 |
+
import random
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from COTR.inference.inference_helper import THRESHOLD_SPARSE, THRESHOLD_AREA, cotr_flow, cotr_corr_base
|
12 |
+
from COTR.inference.refinement_task import RefinementTask
|
13 |
+
from COTR.utils import debug_utils, utils
|
14 |
+
from COTR.cameras.capture import stretch_to_square_np
|
15 |
+
|
16 |
+
|
17 |
+
class SparseEngine():
|
18 |
+
def __init__(self, model, batch_size, mode='stretching'):
|
19 |
+
assert mode in ['stretching', 'tile']
|
20 |
+
self.model = model
|
21 |
+
self.batch_size = batch_size
|
22 |
+
self.total_tasks = 0
|
23 |
+
self.mode = mode
|
24 |
+
|
25 |
+
def form_batch(self, tasks, zoom=None):
|
26 |
+
counter = 0
|
27 |
+
task_ref = []
|
28 |
+
img_batch = []
|
29 |
+
query_batch = []
|
30 |
+
for t in tasks:
|
31 |
+
if t.status == 'unfinished' and t.submitted == False:
|
32 |
+
if zoom is not None and t.cur_zoom != zoom:
|
33 |
+
continue
|
34 |
+
task_ref.append(t)
|
35 |
+
img, query = t.get_task()
|
36 |
+
img_batch.append(img)
|
37 |
+
query_batch.append(query)
|
38 |
+
counter += 1
|
39 |
+
if counter >= self.batch_size:
|
40 |
+
break
|
41 |
+
if len(task_ref) == 0:
|
42 |
+
return [], [], []
|
43 |
+
img_batch = torch.stack(img_batch)
|
44 |
+
query_batch = torch.stack(query_batch)
|
45 |
+
return task_ref, img_batch, query_batch
|
46 |
+
|
47 |
+
def infer_batch(self, img_batch, query_batch):
|
48 |
+
self.total_tasks += img_batch.shape[0]
|
49 |
+
device = next(self.model.parameters()).device
|
50 |
+
img_batch = img_batch.to(device)
|
51 |
+
query_batch = query_batch.to(device)
|
52 |
+
out = self.model(img_batch, query_batch)['pred_corrs'].clone().detach()
|
53 |
+
out = out.cpu().numpy()[:, 0, :]
|
54 |
+
if utils.has_nan(out):
|
55 |
+
raise ValueError('NaN in prediction')
|
56 |
+
return out
|
57 |
+
|
58 |
+
def conclude_tasks(self, tasks, return_idx=False, force=False,
|
59 |
+
offset_x_from=0,
|
60 |
+
offset_y_from=0,
|
61 |
+
offset_x_to=0,
|
62 |
+
offset_y_to=0,
|
63 |
+
img_a_shape=None,
|
64 |
+
img_b_shape=None):
|
65 |
+
corrs = []
|
66 |
+
idx = []
|
67 |
+
for t in tasks:
|
68 |
+
if t.status == 'finished':
|
69 |
+
out = t.conclude(force)
|
70 |
+
if out is not None:
|
71 |
+
corrs.append(np.array(out))
|
72 |
+
idx.append(t.identifier)
|
73 |
+
corrs = np.array(corrs)
|
74 |
+
idx = np.array(idx)
|
75 |
+
if corrs.shape[0] > 0:
|
76 |
+
corrs -= np.array([offset_x_from, offset_y_from, offset_x_to, offset_y_to])
|
77 |
+
if img_a_shape is not None and img_b_shape is not None and not force:
|
78 |
+
border_mask = np.prod(corrs < np.concatenate([img_a_shape[::-1], img_b_shape[::-1]]), axis=1)
|
79 |
+
border_mask = (np.prod(corrs > np.array([0, 0, 0, 0]), axis=1) * border_mask).astype(np.bool)
|
80 |
+
corrs = corrs[border_mask]
|
81 |
+
idx = idx[border_mask]
|
82 |
+
if return_idx:
|
83 |
+
return corrs, idx
|
84 |
+
return corrs
|
85 |
+
|
86 |
+
def num_finished_tasks(self, tasks):
|
87 |
+
counter = 0
|
88 |
+
for t in tasks:
|
89 |
+
if t.status == 'finished':
|
90 |
+
counter += 1
|
91 |
+
return counter
|
92 |
+
|
93 |
+
def num_good_tasks(self, tasks):
|
94 |
+
counter = 0
|
95 |
+
for t in tasks:
|
96 |
+
if t.result == 'good':
|
97 |
+
counter += 1
|
98 |
+
return counter
|
99 |
+
|
100 |
+
def gen_tasks_w_known_scale(self, img_a, img_b, queries_a, areas, zoom_ins=[1.0], converge_iters=1, max_corrs=1000):
|
101 |
+
assert self.mode == 'tile'
|
102 |
+
corr_a = cotr_corr_base(self.model, img_a, img_b, queries_a)
|
103 |
+
tasks = []
|
104 |
+
for c in corr_a:
|
105 |
+
tasks.append(RefinementTask(img_a, img_b, c[:2], c[2:], areas[0], areas[1], converge_iters, zoom_ins))
|
106 |
+
return tasks
|
107 |
+
|
108 |
+
def gen_tasks(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, force=False, areas=None):
|
109 |
+
if areas is not None:
|
110 |
+
assert queries_a is not None
|
111 |
+
assert force == True
|
112 |
+
assert max_corrs >= queries_a.shape[0]
|
113 |
+
return self.gen_tasks_w_known_scale(img_a, img_b, queries_a, areas, zoom_ins=zoom_ins, converge_iters=converge_iters, max_corrs=max_corrs)
|
114 |
+
if self.mode == 'stretching':
|
115 |
+
if img_a.shape[0] != img_a.shape[1] or img_b.shape[0] != img_b.shape[1]:
|
116 |
+
img_a_shape = img_a.shape
|
117 |
+
img_b_shape = img_b.shape
|
118 |
+
img_a_sq = stretch_to_square_np(img_a.copy())
|
119 |
+
img_b_sq = stretch_to_square_np(img_b.copy())
|
120 |
+
corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
|
121 |
+
img_a_sq,
|
122 |
+
img_b_sq
|
123 |
+
)
|
124 |
+
corr_a = utils.float_image_resize(corr_a, img_a_shape[:2])
|
125 |
+
con_a = utils.float_image_resize(con_a, img_a_shape[:2])
|
126 |
+
resample_a = utils.float_image_resize(resample_a, img_a_shape[:2])
|
127 |
+
corr_b = utils.float_image_resize(corr_b, img_b_shape[:2])
|
128 |
+
con_b = utils.float_image_resize(con_b, img_b_shape[:2])
|
129 |
+
resample_b = utils.float_image_resize(resample_b, img_b_shape[:2])
|
130 |
+
else:
|
131 |
+
corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
|
132 |
+
img_a,
|
133 |
+
img_b
|
134 |
+
)
|
135 |
+
elif self.mode == 'tile':
|
136 |
+
corr_a, con_a, resample_a, corr_b, con_b, resample_b = cotr_flow(self.model,
|
137 |
+
img_a,
|
138 |
+
img_b
|
139 |
+
)
|
140 |
+
else:
|
141 |
+
raise ValueError(f'unsupported mode: {self.mode}')
|
142 |
+
mask_a = con_a < THRESHOLD_SPARSE
|
143 |
+
mask_b = con_b < THRESHOLD_SPARSE
|
144 |
+
area_a = (con_a < THRESHOLD_AREA).sum() / mask_a.size
|
145 |
+
area_b = (con_b < THRESHOLD_AREA).sum() / mask_b.size
|
146 |
+
tasks = []
|
147 |
+
|
148 |
+
if queries_a is None:
|
149 |
+
index_a = np.where(mask_a)
|
150 |
+
index_a = np.array(index_a).T
|
151 |
+
index_a = index_a[np.random.choice(len(index_a), min(max_corrs, len(index_a)))]
|
152 |
+
index_b = np.where(mask_b)
|
153 |
+
index_b = np.array(index_b).T
|
154 |
+
index_b = index_b[np.random.choice(len(index_b), min(max_corrs, len(index_b)))]
|
155 |
+
for pos in index_a:
|
156 |
+
loc_from = pos[::-1]
|
157 |
+
loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
|
158 |
+
tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins))
|
159 |
+
for pos in index_b:
|
160 |
+
'''
|
161 |
+
trick: suppose to fix the query point location(loc_from),
|
162 |
+
but here it fixes the first guess(loc_to).
|
163 |
+
'''
|
164 |
+
loc_from = pos[::-1]
|
165 |
+
loc_to = (corr_b[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_a.shape[:2][::-1]
|
166 |
+
tasks.append(RefinementTask(img_a, img_b, loc_to, loc_from, area_a, area_b, converge_iters, zoom_ins))
|
167 |
+
else:
|
168 |
+
if force:
|
169 |
+
for i, loc_from in enumerate(queries_a):
|
170 |
+
pos = loc_from[::-1]
|
171 |
+
pos = np.array([np.clip(pos[0], 0, corr_a.shape[0] - 1), np.clip(pos[1], 0, corr_a.shape[1] - 1)], dtype=np.int)
|
172 |
+
loc_to = (corr_a[tuple(pos)].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
|
173 |
+
tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
|
174 |
+
else:
|
175 |
+
for i, loc_from in enumerate(queries_a):
|
176 |
+
pos = loc_from[::-1]
|
177 |
+
if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any():
|
178 |
+
continue
|
179 |
+
if mask_a[tuple(np.floor(pos).astype('int'))]:
|
180 |
+
loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
|
181 |
+
tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
|
182 |
+
if len(tasks) < max_corrs:
|
183 |
+
extra = max_corrs - len(tasks)
|
184 |
+
counter = 0
|
185 |
+
for i, loc_from in enumerate(queries_a):
|
186 |
+
if counter >= extra:
|
187 |
+
break
|
188 |
+
pos = loc_from[::-1]
|
189 |
+
if (pos > np.array(img_a.shape[:2]) - 1).any() or (pos < 0).any():
|
190 |
+
continue
|
191 |
+
if mask_a[tuple(np.floor(pos).astype('int'))] == False:
|
192 |
+
loc_to = (corr_a[tuple(np.floor(pos).astype('int'))].copy() * 0.5 + 0.5) * img_b.shape[:2][::-1]
|
193 |
+
tasks.append(RefinementTask(img_a, img_b, loc_from, loc_to, area_a, area_b, converge_iters, zoom_ins, identifier=i))
|
194 |
+
counter += 1
|
195 |
+
return tasks
|
196 |
+
|
197 |
+
def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None):
|
198 |
+
'''
|
199 |
+
currently only support fixed queries_a
|
200 |
+
'''
|
201 |
+
img_a = img_a.copy()
|
202 |
+
img_b = img_b.copy()
|
203 |
+
img_a_shape = img_a.shape[:2]
|
204 |
+
img_b_shape = img_b.shape[:2]
|
205 |
+
if queries_a is not None:
|
206 |
+
queries_a = queries_a.copy()
|
207 |
+
tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas)
|
208 |
+
while True:
|
209 |
+
num_g = self.num_good_tasks(tasks)
|
210 |
+
print(f'{num_g} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}')
|
211 |
+
task_ref, img_batch, query_batch = self.form_batch(tasks)
|
212 |
+
if len(task_ref) == 0:
|
213 |
+
break
|
214 |
+
if num_g >= max_corrs:
|
215 |
+
break
|
216 |
+
out = self.infer_batch(img_batch, query_batch)
|
217 |
+
for t, o in zip(task_ref, out):
|
218 |
+
t.step(o)
|
219 |
+
if return_tasks_only:
|
220 |
+
return tasks
|
221 |
+
if return_idx:
|
222 |
+
corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force,
|
223 |
+
img_a_shape=img_a_shape,
|
224 |
+
img_b_shape=img_b_shape,)
|
225 |
+
corrs = corrs[:max_corrs]
|
226 |
+
idx = idx[:max_corrs]
|
227 |
+
return corrs, idx
|
228 |
+
else:
|
229 |
+
corrs = self.conclude_tasks(tasks, force=force,
|
230 |
+
img_a_shape=img_a_shape,
|
231 |
+
img_b_shape=img_b_shape,)
|
232 |
+
corrs = corrs[:max_corrs]
|
233 |
+
return corrs
|
234 |
+
|
235 |
+
def cotr_corr_multiscale_with_cycle_consistency(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, return_cycle_error=False):
|
236 |
+
EXTRACTION_RATE = 0.3
|
237 |
+
temp_max_corrs = int(max_corrs / EXTRACTION_RATE)
|
238 |
+
if queries_a is not None:
|
239 |
+
temp_max_corrs = min(temp_max_corrs, queries_a.shape[0])
|
240 |
+
queries_a = queries_a.copy()
|
241 |
+
corr_f, idx_f = self.cotr_corr_multiscale(img_a.copy(), img_b.copy(),
|
242 |
+
zoom_ins=zoom_ins,
|
243 |
+
converge_iters=converge_iters,
|
244 |
+
max_corrs=temp_max_corrs,
|
245 |
+
queries_a=queries_a,
|
246 |
+
return_idx=True)
|
247 |
+
assert corr_f.shape[0] > 0
|
248 |
+
corr_b, idx_b = self.cotr_corr_multiscale(img_b.copy(), img_a.copy(),
|
249 |
+
zoom_ins=zoom_ins,
|
250 |
+
converge_iters=converge_iters,
|
251 |
+
max_corrs=corr_f.shape[0],
|
252 |
+
queries_a=corr_f[:, 2:].copy(),
|
253 |
+
return_idx=True)
|
254 |
+
assert corr_b.shape[0] > 0
|
255 |
+
cycle_errors = np.linalg.norm(corr_f[idx_b][:, :2] - corr_b[:, 2:], axis=1)
|
256 |
+
order = np.argsort(cycle_errors)
|
257 |
+
out = [corr_f[idx_b][order][:max_corrs]]
|
258 |
+
if return_idx:
|
259 |
+
out.append(idx_f[idx_b][order][:max_corrs])
|
260 |
+
if return_cycle_error:
|
261 |
+
out.append(cycle_errors[order][:max_corrs])
|
262 |
+
if len(out) == 1:
|
263 |
+
out = out[0]
|
264 |
+
return out
|
265 |
+
|
266 |
+
|
267 |
+
class FasterSparseEngine(SparseEngine):
|
268 |
+
'''
|
269 |
+
search and merge nearby tasks to accelerate inference speed.
|
270 |
+
It will make spatial accuracy slightly worse.
|
271 |
+
'''
|
272 |
+
|
273 |
+
def __init__(self, model, batch_size, mode='stretching', max_load=256):
|
274 |
+
super().__init__(model, batch_size, mode=mode)
|
275 |
+
self.max_load = max_load
|
276 |
+
|
277 |
+
def infer_batch_grouped(self, img_batch, query_batch):
|
278 |
+
device = next(self.model.parameters()).device
|
279 |
+
img_batch = img_batch.to(device)
|
280 |
+
query_batch = query_batch.to(device)
|
281 |
+
out = self.model(img_batch, query_batch)['pred_corrs'].clone().detach().cpu().numpy()
|
282 |
+
return out
|
283 |
+
|
284 |
+
def get_tasks_map(self, zoom, tasks):
|
285 |
+
maps = []
|
286 |
+
ids = []
|
287 |
+
for i, t in enumerate(tasks):
|
288 |
+
if t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom:
|
289 |
+
t_info = t.peek()
|
290 |
+
point = np.concatenate([t_info['loc_from'], t_info['loc_to']])
|
291 |
+
maps.append(point)
|
292 |
+
ids.append(i)
|
293 |
+
return np.array(maps), np.array(ids)
|
294 |
+
|
295 |
+
def form_squad(self, zoom, pilot, pilot_id, tasks, tasks_map, task_ids, bookkeeping):
|
296 |
+
assert pilot.status == 'unfinished' and pilot.submitted == False and pilot.cur_zoom == zoom
|
297 |
+
SAFE_AREA = 0.5
|
298 |
+
pilot_info = pilot.peek()
|
299 |
+
pilot_from_center_x = pilot_info['patch_from'].x + pilot_info['patch_from'].w/2
|
300 |
+
pilot_from_center_y = pilot_info['patch_from'].y + pilot_info['patch_from'].h/2
|
301 |
+
pilot_from_left = pilot_from_center_x - pilot_info['patch_from'].w/2 * SAFE_AREA
|
302 |
+
pilot_from_right = pilot_from_center_x + pilot_info['patch_from'].w/2 * SAFE_AREA
|
303 |
+
pilot_from_upper = pilot_from_center_y - pilot_info['patch_from'].h/2 * SAFE_AREA
|
304 |
+
pilot_from_lower = pilot_from_center_y + pilot_info['patch_from'].h/2 * SAFE_AREA
|
305 |
+
|
306 |
+
pilot_to_center_x = pilot_info['patch_to'].x + pilot_info['patch_to'].w/2
|
307 |
+
pilot_to_center_y = pilot_info['patch_to'].y + pilot_info['patch_to'].h/2
|
308 |
+
pilot_to_left = pilot_to_center_x - pilot_info['patch_to'].w/2 * SAFE_AREA
|
309 |
+
pilot_to_right = pilot_to_center_x + pilot_info['patch_to'].w/2 * SAFE_AREA
|
310 |
+
pilot_to_upper = pilot_to_center_y - pilot_info['patch_to'].h/2 * SAFE_AREA
|
311 |
+
pilot_to_lower = pilot_to_center_y + pilot_info['patch_to'].h/2 * SAFE_AREA
|
312 |
+
|
313 |
+
img, query = pilot.get_task()
|
314 |
+
assert pilot.submitted == True
|
315 |
+
members = [pilot]
|
316 |
+
queries = [query]
|
317 |
+
bookkeeping[pilot_id] = False
|
318 |
+
|
319 |
+
loads = np.where(((tasks_map[:, 0] > pilot_from_left) &
|
320 |
+
(tasks_map[:, 0] < pilot_from_right) &
|
321 |
+
(tasks_map[:, 1] > pilot_from_upper) &
|
322 |
+
(tasks_map[:, 1] < pilot_from_lower) &
|
323 |
+
(tasks_map[:, 2] > pilot_to_left) &
|
324 |
+
(tasks_map[:, 2] < pilot_to_right) &
|
325 |
+
(tasks_map[:, 3] > pilot_to_upper) &
|
326 |
+
(tasks_map[:, 3] < pilot_to_lower)) *
|
327 |
+
bookkeeping)[0][: self.max_load]
|
328 |
+
|
329 |
+
for ti in task_ids[loads]:
|
330 |
+
t = tasks[ti]
|
331 |
+
assert t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom
|
332 |
+
_, query = t.get_task_pilot(pilot)
|
333 |
+
members.append(t)
|
334 |
+
queries.append(query)
|
335 |
+
queries = torch.stack(queries, axis=1)
|
336 |
+
bookkeeping[loads] = False
|
337 |
+
return members, img, queries, bookkeeping
|
338 |
+
|
339 |
+
def form_grouped_batch(self, zoom, tasks):
|
340 |
+
counter = 0
|
341 |
+
task_ref = []
|
342 |
+
img_batch = []
|
343 |
+
query_batch = []
|
344 |
+
tasks_map, task_ids = self.get_tasks_map(zoom, tasks)
|
345 |
+
shuffle = np.random.permutation(tasks_map.shape[0])
|
346 |
+
tasks_map = np.take(tasks_map, shuffle, axis=0)
|
347 |
+
task_ids = np.take(task_ids, shuffle, axis=0)
|
348 |
+
bookkeeping = np.ones_like(task_ids).astype(bool)
|
349 |
+
|
350 |
+
for i, ti in enumerate(task_ids):
|
351 |
+
t = tasks[ti]
|
352 |
+
if t.status == 'unfinished' and t.submitted == False and t.cur_zoom == zoom:
|
353 |
+
members, img, queries, bookkeeping = self.form_squad(zoom, t, i, tasks, tasks_map, task_ids, bookkeeping)
|
354 |
+
task_ref.append(members)
|
355 |
+
img_batch.append(img)
|
356 |
+
query_batch.append(queries)
|
357 |
+
counter += 1
|
358 |
+
if counter >= self.batch_size:
|
359 |
+
break
|
360 |
+
if len(task_ref) == 0:
|
361 |
+
return [], [], []
|
362 |
+
|
363 |
+
max_len = max([q.shape[1] for q in query_batch])
|
364 |
+
for i in range(len(query_batch)):
|
365 |
+
q = query_batch[i]
|
366 |
+
query_batch[i] = torch.cat([q, torch.zeros([1, max_len - q.shape[1], 2])], axis=1)
|
367 |
+
img_batch = torch.stack(img_batch)
|
368 |
+
query_batch = torch.cat(query_batch)
|
369 |
+
return task_ref, img_batch, query_batch
|
370 |
+
|
371 |
+
def cotr_corr_multiscale(self, img_a, img_b, zoom_ins=[1.0], converge_iters=1, max_corrs=1000, queries_a=None, return_idx=False, force=False, return_tasks_only=False, areas=None):
|
372 |
+
'''
|
373 |
+
currently only support fixed queries_a
|
374 |
+
'''
|
375 |
+
img_a = img_a.copy()
|
376 |
+
img_b = img_b.copy()
|
377 |
+
img_a_shape = img_a.shape[:2]
|
378 |
+
img_b_shape = img_b.shape[:2]
|
379 |
+
if queries_a is not None:
|
380 |
+
queries_a = queries_a.copy()
|
381 |
+
tasks = self.gen_tasks(img_a, img_b, zoom_ins, converge_iters, max_corrs, queries_a, force, areas)
|
382 |
+
for zm in zoom_ins:
|
383 |
+
print(f'======= Zoom: {zm} ======')
|
384 |
+
while True:
|
385 |
+
num_g = self.num_good_tasks(tasks)
|
386 |
+
task_ref, img_batch, query_batch = self.form_grouped_batch(zm, tasks)
|
387 |
+
if len(task_ref) == 0:
|
388 |
+
break
|
389 |
+
if num_g >= max_corrs:
|
390 |
+
break
|
391 |
+
out = self.infer_batch_grouped(img_batch, query_batch)
|
392 |
+
num_steps = 0
|
393 |
+
for i, temp in enumerate(task_ref):
|
394 |
+
for j, t in enumerate(temp):
|
395 |
+
t.step(out[i, j])
|
396 |
+
num_steps += 1
|
397 |
+
print(f'solved {num_steps} sub-tasks in one invocation with {img_batch.shape[0]} image pairs')
|
398 |
+
if num_steps <= self.batch_size:
|
399 |
+
break
|
400 |
+
# Rollback to default inference, because of too few valid tasks can be grouped together.
|
401 |
+
while True:
|
402 |
+
num_g = self.num_good_tasks(tasks)
|
403 |
+
print(f'{num_g} / {max_corrs} | {self.num_finished_tasks(tasks)} / {len(tasks)}')
|
404 |
+
task_ref, img_batch, query_batch = self.form_batch(tasks, zm)
|
405 |
+
if len(task_ref) == 0:
|
406 |
+
break
|
407 |
+
if num_g >= max_corrs:
|
408 |
+
break
|
409 |
+
out = self.infer_batch(img_batch, query_batch)
|
410 |
+
for t, o in zip(task_ref, out):
|
411 |
+
t.step(o)
|
412 |
+
|
413 |
+
if return_tasks_only:
|
414 |
+
return tasks
|
415 |
+
if return_idx:
|
416 |
+
corrs, idx = self.conclude_tasks(tasks, return_idx=True, force=force,
|
417 |
+
img_a_shape=img_a_shape,
|
418 |
+
img_b_shape=img_b_shape,)
|
419 |
+
corrs = corrs[:max_corrs]
|
420 |
+
idx = idx[:max_corrs]
|
421 |
+
return corrs, idx
|
422 |
+
else:
|
423 |
+
corrs = self.conclude_tasks(tasks, force=force,
|
424 |
+
img_a_shape=img_a_shape,
|
425 |
+
img_b_shape=img_b_shape,)
|
426 |
+
corrs = corrs[:max_corrs]
|
427 |
+
return corrs
|
third_party/COTR/COTR/models/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
The COTR model is modified from DETR code base.
|
3 |
+
https://github.com/facebookresearch/detr
|
4 |
+
'''
|
5 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
6 |
+
from .cotr_model import build
|
7 |
+
|
8 |
+
|
9 |
+
def build_model(args):
|
10 |
+
return build(args)
|
third_party/COTR/COTR/models/backbone.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Backbone modules.
|
4 |
+
"""
|
5 |
+
from collections import OrderedDict
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torchvision
|
10 |
+
from torch import nn
|
11 |
+
from torchvision.models._utils import IntermediateLayerGetter
|
12 |
+
from typing import Dict, List
|
13 |
+
|
14 |
+
from .misc import NestedTensor
|
15 |
+
|
16 |
+
from .position_encoding import build_position_encoding
|
17 |
+
from COTR.utils import debug_utils, constants
|
18 |
+
|
19 |
+
|
20 |
+
class FrozenBatchNorm2d(torch.nn.Module):
|
21 |
+
"""
|
22 |
+
BatchNorm2d where the batch statistics and the affine parameters are fixed.
|
23 |
+
|
24 |
+
Copy-paste from torchvision.misc.ops with added eps before rqsrt,
|
25 |
+
without which any other models than torchvision.models.resnet[18,34,50,101]
|
26 |
+
produce nans.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self, n):
|
30 |
+
super(FrozenBatchNorm2d, self).__init__()
|
31 |
+
self.register_buffer("weight", torch.ones(n))
|
32 |
+
self.register_buffer("bias", torch.zeros(n))
|
33 |
+
self.register_buffer("running_mean", torch.zeros(n))
|
34 |
+
self.register_buffer("running_var", torch.ones(n))
|
35 |
+
|
36 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
37 |
+
missing_keys, unexpected_keys, error_msgs):
|
38 |
+
num_batches_tracked_key = prefix + 'num_batches_tracked'
|
39 |
+
if num_batches_tracked_key in state_dict:
|
40 |
+
del state_dict[num_batches_tracked_key]
|
41 |
+
|
42 |
+
super(FrozenBatchNorm2d, self)._load_from_state_dict(
|
43 |
+
state_dict, prefix, local_metadata, strict,
|
44 |
+
missing_keys, unexpected_keys, error_msgs)
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
# move reshapes to the beginning
|
48 |
+
# to make it fuser-friendly
|
49 |
+
w = self.weight.reshape(1, -1, 1, 1)
|
50 |
+
b = self.bias.reshape(1, -1, 1, 1)
|
51 |
+
rv = self.running_var.reshape(1, -1, 1, 1)
|
52 |
+
rm = self.running_mean.reshape(1, -1, 1, 1)
|
53 |
+
eps = 1e-5
|
54 |
+
scale = w * (rv + eps).rsqrt()
|
55 |
+
bias = b - rm * scale
|
56 |
+
return x * scale + bias
|
57 |
+
|
58 |
+
|
59 |
+
class BackboneBase(nn.Module):
|
60 |
+
|
61 |
+
def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool, layer='layer3'):
|
62 |
+
super().__init__()
|
63 |
+
for name, parameter in backbone.named_parameters():
|
64 |
+
if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
|
65 |
+
parameter.requires_grad_(False)
|
66 |
+
# print(f'freeze {name}')
|
67 |
+
if return_interm_layers:
|
68 |
+
return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
|
69 |
+
else:
|
70 |
+
return_layers = {layer: "0"}
|
71 |
+
self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
|
72 |
+
self.num_channels = num_channels
|
73 |
+
|
74 |
+
def forward_raw(self, x):
|
75 |
+
y = self.body(x)
|
76 |
+
assert len(y.keys()) == 1
|
77 |
+
return y['0']
|
78 |
+
|
79 |
+
def forward(self, tensor_list: NestedTensor):
|
80 |
+
assert tensor_list.tensors.shape[-2:] == (constants.MAX_SIZE, constants.MAX_SIZE * 2)
|
81 |
+
left = self.body(tensor_list.tensors[..., 0:constants.MAX_SIZE])
|
82 |
+
right = self.body(tensor_list.tensors[..., constants.MAX_SIZE:2 * constants.MAX_SIZE])
|
83 |
+
xs = {}
|
84 |
+
for k in left.keys():
|
85 |
+
xs[k] = torch.cat([left[k], right[k]], dim=-1)
|
86 |
+
out: Dict[str, NestedTensor] = {}
|
87 |
+
for name, x in xs.items():
|
88 |
+
m = tensor_list.mask
|
89 |
+
assert m is not None
|
90 |
+
mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
|
91 |
+
out[name] = NestedTensor(x, mask)
|
92 |
+
return out
|
93 |
+
|
94 |
+
|
95 |
+
class Backbone(BackboneBase):
|
96 |
+
"""ResNet backbone with frozen BatchNorm."""
|
97 |
+
|
98 |
+
def __init__(self, name: str,
|
99 |
+
train_backbone: bool,
|
100 |
+
return_interm_layers: bool,
|
101 |
+
dilation: bool,
|
102 |
+
layer='layer3',
|
103 |
+
num_channels=1024):
|
104 |
+
backbone = getattr(torchvision.models, name)(
|
105 |
+
replace_stride_with_dilation=[False, False, dilation],
|
106 |
+
pretrained=True, norm_layer=FrozenBatchNorm2d)
|
107 |
+
super().__init__(backbone, train_backbone, num_channels, return_interm_layers, layer)
|
108 |
+
|
109 |
+
|
110 |
+
class Joiner(nn.Sequential):
|
111 |
+
def __init__(self, backbone, position_embedding):
|
112 |
+
super().__init__(backbone, position_embedding)
|
113 |
+
|
114 |
+
def forward(self, tensor_list: NestedTensor):
|
115 |
+
xs = self[0](tensor_list)
|
116 |
+
out: List[NestedTensor] = []
|
117 |
+
pos = []
|
118 |
+
for name, x in xs.items():
|
119 |
+
out.append(x)
|
120 |
+
# position encoding
|
121 |
+
pos.append(self[1](x).to(x.tensors.dtype))
|
122 |
+
|
123 |
+
return out, pos
|
124 |
+
|
125 |
+
|
126 |
+
def build_backbone(args):
|
127 |
+
position_embedding = build_position_encoding(args)
|
128 |
+
if hasattr(args, 'lr_backbone'):
|
129 |
+
train_backbone = args.lr_backbone > 0
|
130 |
+
else:
|
131 |
+
train_backbone = False
|
132 |
+
backbone = Backbone(args.backbone, train_backbone, False, args.dilation, layer=args.layer, num_channels=args.dim_feedforward)
|
133 |
+
model = Joiner(backbone, position_embedding)
|
134 |
+
model.num_channels = backbone.num_channels
|
135 |
+
return model
|
third_party/COTR/COTR/models/cotr_model.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from COTR.utils import debug_utils, constants, utils
|
9 |
+
from .misc import (NestedTensor, nested_tensor_from_tensor_list)
|
10 |
+
from .backbone import build_backbone
|
11 |
+
from .transformer import build_transformer
|
12 |
+
from .position_encoding import NerfPositionalEncoding, MLP
|
13 |
+
|
14 |
+
|
15 |
+
class COTR(nn.Module):
|
16 |
+
|
17 |
+
def __init__(self, backbone, transformer, sine_type='lin_sine'):
|
18 |
+
super().__init__()
|
19 |
+
self.transformer = transformer
|
20 |
+
hidden_dim = transformer.d_model
|
21 |
+
self.corr_embed = MLP(hidden_dim, hidden_dim, 2, 3)
|
22 |
+
self.query_proj = NerfPositionalEncoding(hidden_dim // 4, sine_type)
|
23 |
+
self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
|
24 |
+
self.backbone = backbone
|
25 |
+
|
26 |
+
def forward(self, samples: NestedTensor, queries):
|
27 |
+
if isinstance(samples, (list, torch.Tensor)):
|
28 |
+
samples = nested_tensor_from_tensor_list(samples)
|
29 |
+
features, pos = self.backbone(samples)
|
30 |
+
|
31 |
+
src, mask = features[-1].decompose()
|
32 |
+
assert mask is not None
|
33 |
+
_b, _q, _ = queries.shape
|
34 |
+
queries = queries.reshape(-1, 2)
|
35 |
+
queries = self.query_proj(queries).reshape(_b, _q, -1)
|
36 |
+
queries = queries.permute(1, 0, 2)
|
37 |
+
hs = self.transformer(self.input_proj(src), mask, queries, pos[-1])[0]
|
38 |
+
outputs_corr = self.corr_embed(hs)
|
39 |
+
out = {'pred_corrs': outputs_corr[-1]}
|
40 |
+
return out
|
41 |
+
|
42 |
+
|
43 |
+
def build(args):
|
44 |
+
backbone = build_backbone(args)
|
45 |
+
transformer = build_transformer(args)
|
46 |
+
model = COTR(
|
47 |
+
backbone,
|
48 |
+
transformer,
|
49 |
+
sine_type=args.position_embedding,
|
50 |
+
)
|
51 |
+
return model
|
third_party/COTR/COTR/models/misc.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Misc functions, including distributed helpers.
|
4 |
+
|
5 |
+
Mostly copy-paste from torchvision references.
|
6 |
+
"""
|
7 |
+
import os
|
8 |
+
import subprocess
|
9 |
+
import time
|
10 |
+
from collections import defaultdict, deque
|
11 |
+
import datetime
|
12 |
+
import pickle
|
13 |
+
from typing import Optional, List
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.distributed as dist
|
17 |
+
from torch import Tensor
|
18 |
+
|
19 |
+
# needed due to empty tensor bug in pytorch and torchvision 0.5
|
20 |
+
import torchvision
|
21 |
+
if float(torchvision.__version__.split('.')[1]) < 7:
|
22 |
+
from torchvision.ops import _new_empty_tensor
|
23 |
+
from torchvision.ops.misc import _output_size
|
24 |
+
|
25 |
+
|
26 |
+
def _max_by_axis(the_list):
|
27 |
+
# type: (List[List[int]]) -> List[int]
|
28 |
+
maxes = the_list[0]
|
29 |
+
for sublist in the_list[1:]:
|
30 |
+
for index, item in enumerate(sublist):
|
31 |
+
maxes[index] = max(maxes[index], item)
|
32 |
+
return maxes
|
33 |
+
|
34 |
+
|
35 |
+
class NestedTensor(object):
|
36 |
+
def __init__(self, tensors, mask: Optional[Tensor]):
|
37 |
+
self.tensors = tensors
|
38 |
+
self.mask = mask
|
39 |
+
|
40 |
+
def to(self, device):
|
41 |
+
# type: (Device) -> NestedTensor # noqa
|
42 |
+
cast_tensor = self.tensors.to(device)
|
43 |
+
mask = self.mask
|
44 |
+
if mask is not None:
|
45 |
+
assert mask is not None
|
46 |
+
cast_mask = mask.to(device)
|
47 |
+
else:
|
48 |
+
cast_mask = None
|
49 |
+
return NestedTensor(cast_tensor, cast_mask)
|
50 |
+
|
51 |
+
def decompose(self):
|
52 |
+
return self.tensors, self.mask
|
53 |
+
|
54 |
+
def __repr__(self):
|
55 |
+
return str(self.tensors)
|
56 |
+
|
57 |
+
|
58 |
+
def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
|
59 |
+
# TODO make this more general
|
60 |
+
if tensor_list[0].ndim == 3:
|
61 |
+
if torchvision._is_tracing():
|
62 |
+
# nested_tensor_from_tensor_list() does not export well to ONNX
|
63 |
+
# call _onnx_nested_tensor_from_tensor_list() instead
|
64 |
+
return _onnx_nested_tensor_from_tensor_list(tensor_list)
|
65 |
+
|
66 |
+
# TODO make it support different-sized images
|
67 |
+
max_size = _max_by_axis([list(img.shape) for img in tensor_list])
|
68 |
+
# min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
|
69 |
+
batch_shape = [len(tensor_list)] + max_size
|
70 |
+
b, c, h, w = batch_shape
|
71 |
+
dtype = tensor_list[0].dtype
|
72 |
+
device = tensor_list[0].device
|
73 |
+
tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
|
74 |
+
mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
|
75 |
+
for img, pad_img, m in zip(tensor_list, tensor, mask):
|
76 |
+
pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
77 |
+
m[: img.shape[1], :img.shape[2]] = False
|
78 |
+
else:
|
79 |
+
raise ValueError('not supported')
|
80 |
+
return NestedTensor(tensor, mask)
|
81 |
+
|
82 |
+
|
83 |
+
# _onnx_nested_tensor_from_tensor_list() is an implementation of
|
84 |
+
# nested_tensor_from_tensor_list() that is supported by ONNX tracing.
|
85 |
+
@torch.jit.unused
|
86 |
+
def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
|
87 |
+
max_size = []
|
88 |
+
for i in range(tensor_list[0].dim()):
|
89 |
+
max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
|
90 |
+
max_size.append(max_size_i)
|
91 |
+
max_size = tuple(max_size)
|
92 |
+
|
93 |
+
# work around for
|
94 |
+
# pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
|
95 |
+
# m[: img.shape[1], :img.shape[2]] = False
|
96 |
+
# which is not yet supported in onnx
|
97 |
+
padded_imgs = []
|
98 |
+
padded_masks = []
|
99 |
+
for img in tensor_list:
|
100 |
+
padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
|
101 |
+
padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
|
102 |
+
padded_imgs.append(padded_img)
|
103 |
+
|
104 |
+
m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
|
105 |
+
padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
|
106 |
+
padded_masks.append(padded_mask.to(torch.bool))
|
107 |
+
|
108 |
+
tensor = torch.stack(padded_imgs)
|
109 |
+
mask = torch.stack(padded_masks)
|
110 |
+
|
111 |
+
return NestedTensor(tensor, mask=mask)
|
112 |
+
|
third_party/COTR/COTR/models/position_encoding.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
Various positional encodings for the transformer.
|
4 |
+
"""
|
5 |
+
import math
|
6 |
+
import torch
|
7 |
+
from torch import nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from .misc import NestedTensor
|
11 |
+
from COTR.utils import debug_utils
|
12 |
+
|
13 |
+
|
14 |
+
class MLP(nn.Module):
|
15 |
+
""" Very simple multi-layer perceptron (also called FFN)"""
|
16 |
+
|
17 |
+
def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
|
18 |
+
super().__init__()
|
19 |
+
self.num_layers = num_layers
|
20 |
+
h = [hidden_dim] * (num_layers - 1)
|
21 |
+
self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
|
22 |
+
|
23 |
+
def forward(self, x):
|
24 |
+
for i, layer in enumerate(self.layers):
|
25 |
+
x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
|
26 |
+
return x
|
27 |
+
|
28 |
+
|
29 |
+
class NerfPositionalEncoding(nn.Module):
|
30 |
+
def __init__(self, depth=10, sine_type='lin_sine'):
|
31 |
+
'''
|
32 |
+
out_dim = in_dim * depth * 2
|
33 |
+
'''
|
34 |
+
super().__init__()
|
35 |
+
if sine_type == 'lin_sine':
|
36 |
+
self.bases = [i+1 for i in range(depth)]
|
37 |
+
elif sine_type == 'exp_sine':
|
38 |
+
self.bases = [2**i for i in range(depth)]
|
39 |
+
print(f'using {sine_type} as positional encoding')
|
40 |
+
|
41 |
+
@torch.no_grad()
|
42 |
+
def forward(self, inputs):
|
43 |
+
out = torch.cat([torch.sin(i * math.pi * inputs) for i in self.bases] + [torch.cos(i * math.pi * inputs) for i in self.bases], axis=-1)
|
44 |
+
assert torch.isnan(out).any() == False
|
45 |
+
return out
|
46 |
+
|
47 |
+
|
48 |
+
class PositionEmbeddingSine(nn.Module):
|
49 |
+
"""
|
50 |
+
This is a more standard version of the position embedding, very similar to the one
|
51 |
+
used by the Attention is all you need paper, generalized to work on images.
|
52 |
+
"""
|
53 |
+
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None, sine_type='lin_sine'):
|
54 |
+
super().__init__()
|
55 |
+
self.num_pos_feats = num_pos_feats
|
56 |
+
self.temperature = temperature
|
57 |
+
self.normalize = normalize
|
58 |
+
self.sine = NerfPositionalEncoding(num_pos_feats//2, sine_type)
|
59 |
+
|
60 |
+
@torch.no_grad()
|
61 |
+
def forward(self, tensor_list: NestedTensor):
|
62 |
+
x = tensor_list.tensors
|
63 |
+
mask = tensor_list.mask
|
64 |
+
assert mask is not None
|
65 |
+
not_mask = ~mask
|
66 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
67 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
68 |
+
eps = 1e-6
|
69 |
+
y_embed = (y_embed-0.5) / (y_embed[:, -1:, :] + eps)
|
70 |
+
x_embed = (x_embed-0.5) / (x_embed[:, :, -1:] + eps)
|
71 |
+
pos = torch.stack([x_embed, y_embed], dim=-1)
|
72 |
+
return self.sine(pos).permute(0, 3, 1, 2)
|
73 |
+
|
74 |
+
|
75 |
+
def build_position_encoding(args):
|
76 |
+
N_steps = args.hidden_dim // 2
|
77 |
+
if args.position_embedding in ('lin_sine', 'exp_sine'):
|
78 |
+
# TODO find a better way of exposing other arguments
|
79 |
+
position_embedding = PositionEmbeddingSine(N_steps, normalize=True, sine_type=args.position_embedding)
|
80 |
+
else:
|
81 |
+
raise ValueError(f"not supported {args.position_embedding}")
|
82 |
+
|
83 |
+
return position_embedding
|
third_party/COTR/COTR/models/transformer.py
ADDED
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
|
2 |
+
"""
|
3 |
+
COTR/DETR Transformer class.
|
4 |
+
|
5 |
+
Copy-paste from torch.nn.Transformer with modifications:
|
6 |
+
* positional encodings are passed in MHattention
|
7 |
+
* extra LN at the end of encoder is removed
|
8 |
+
* decoder returns a stack of activations from all decoding layers
|
9 |
+
"""
|
10 |
+
import copy
|
11 |
+
from typing import Optional, List
|
12 |
+
|
13 |
+
import torch
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch import nn, Tensor
|
16 |
+
|
17 |
+
from COTR.utils import debug_utils
|
18 |
+
|
19 |
+
|
20 |
+
class Transformer(nn.Module):
|
21 |
+
|
22 |
+
def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
|
23 |
+
num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
|
24 |
+
activation="relu", return_intermediate_dec=False):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
|
28 |
+
dropout, activation)
|
29 |
+
self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers)
|
30 |
+
|
31 |
+
decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
|
32 |
+
dropout, activation)
|
33 |
+
decoder_norm = nn.LayerNorm(d_model)
|
34 |
+
self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
|
35 |
+
return_intermediate=return_intermediate_dec)
|
36 |
+
|
37 |
+
self._reset_parameters()
|
38 |
+
|
39 |
+
self.d_model = d_model
|
40 |
+
self.nhead = nhead
|
41 |
+
|
42 |
+
def _reset_parameters(self):
|
43 |
+
for p in self.parameters():
|
44 |
+
if p.dim() > 1:
|
45 |
+
nn.init.xavier_uniform_(p)
|
46 |
+
|
47 |
+
def forward(self, src, mask, query_embed, pos_embed):
|
48 |
+
# flatten NxCxHxW to HWxNxC
|
49 |
+
bs, c, h, w = src.shape
|
50 |
+
src = src.flatten(2).permute(2, 0, 1)
|
51 |
+
pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
|
52 |
+
mask = mask.flatten(1)
|
53 |
+
|
54 |
+
tgt = torch.zeros_like(query_embed)
|
55 |
+
memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
|
56 |
+
hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
|
57 |
+
pos=pos_embed, query_pos=query_embed)
|
58 |
+
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)
|
59 |
+
|
60 |
+
|
61 |
+
class TransformerEncoder(nn.Module):
|
62 |
+
|
63 |
+
def __init__(self, encoder_layer, num_layers):
|
64 |
+
super().__init__()
|
65 |
+
self.layers = _get_clones(encoder_layer, num_layers)
|
66 |
+
self.num_layers = num_layers
|
67 |
+
|
68 |
+
def forward(self, src,
|
69 |
+
mask: Optional[Tensor] = None,
|
70 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
71 |
+
pos: Optional[Tensor] = None):
|
72 |
+
output = src
|
73 |
+
|
74 |
+
for layer in self.layers:
|
75 |
+
output = layer(output, src_mask=mask,
|
76 |
+
src_key_padding_mask=src_key_padding_mask, pos=pos)
|
77 |
+
|
78 |
+
return output
|
79 |
+
|
80 |
+
|
81 |
+
class TransformerDecoder(nn.Module):
|
82 |
+
|
83 |
+
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
84 |
+
super().__init__()
|
85 |
+
self.layers = _get_clones(decoder_layer, num_layers)
|
86 |
+
self.num_layers = num_layers
|
87 |
+
self.norm = norm
|
88 |
+
self.return_intermediate = return_intermediate
|
89 |
+
|
90 |
+
def forward(self, tgt, memory,
|
91 |
+
tgt_mask: Optional[Tensor] = None,
|
92 |
+
memory_mask: Optional[Tensor] = None,
|
93 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
94 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
95 |
+
pos: Optional[Tensor] = None,
|
96 |
+
query_pos: Optional[Tensor] = None):
|
97 |
+
output = tgt
|
98 |
+
|
99 |
+
intermediate = []
|
100 |
+
|
101 |
+
for layer in self.layers:
|
102 |
+
output = layer(output, memory, tgt_mask=tgt_mask,
|
103 |
+
memory_mask=memory_mask,
|
104 |
+
tgt_key_padding_mask=tgt_key_padding_mask,
|
105 |
+
memory_key_padding_mask=memory_key_padding_mask,
|
106 |
+
pos=pos, query_pos=query_pos)
|
107 |
+
if self.return_intermediate:
|
108 |
+
intermediate.append(self.norm(output))
|
109 |
+
|
110 |
+
if self.norm is not None:
|
111 |
+
output = self.norm(output)
|
112 |
+
if self.return_intermediate:
|
113 |
+
intermediate.pop()
|
114 |
+
intermediate.append(output)
|
115 |
+
|
116 |
+
if self.return_intermediate:
|
117 |
+
return torch.stack(intermediate)
|
118 |
+
|
119 |
+
return output.unsqueeze(0)
|
120 |
+
|
121 |
+
|
122 |
+
class TransformerEncoderLayer(nn.Module):
|
123 |
+
|
124 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
125 |
+
activation="relu"):
|
126 |
+
super().__init__()
|
127 |
+
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
128 |
+
# Implementation of Feedforward model
|
129 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
130 |
+
self.dropout = nn.Dropout(dropout)
|
131 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
132 |
+
|
133 |
+
self.norm1 = nn.LayerNorm(d_model)
|
134 |
+
self.norm2 = nn.LayerNorm(d_model)
|
135 |
+
self.dropout1 = nn.Dropout(dropout)
|
136 |
+
self.dropout2 = nn.Dropout(dropout)
|
137 |
+
|
138 |
+
self.activation = _get_activation_fn(activation)
|
139 |
+
|
140 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
141 |
+
return tensor if pos is None else tensor + pos
|
142 |
+
|
143 |
+
def forward(self,
|
144 |
+
src,
|
145 |
+
src_mask: Optional[Tensor] = None,
|
146 |
+
src_key_padding_mask: Optional[Tensor] = None,
|
147 |
+
pos: Optional[Tensor] = None):
|
148 |
+
q = k = self.with_pos_embed(src, pos)
|
149 |
+
src2 = self.self_attn(query=q,
|
150 |
+
key=k,
|
151 |
+
value=src,
|
152 |
+
attn_mask=src_mask,
|
153 |
+
key_padding_mask=src_key_padding_mask)[0]
|
154 |
+
src = src + self.dropout1(src2)
|
155 |
+
src = self.norm1(src)
|
156 |
+
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
157 |
+
src = src + self.dropout2(src2)
|
158 |
+
src = self.norm2(src)
|
159 |
+
return src
|
160 |
+
|
161 |
+
|
162 |
+
class TransformerDecoderLayer(nn.Module):
|
163 |
+
|
164 |
+
def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
|
165 |
+
activation="relu"):
|
166 |
+
super().__init__()
|
167 |
+
self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
168 |
+
# Implementation of Feedforward model
|
169 |
+
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
170 |
+
self.dropout = nn.Dropout(dropout)
|
171 |
+
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
172 |
+
|
173 |
+
self.norm1 = nn.LayerNorm(d_model)
|
174 |
+
self.norm2 = nn.LayerNorm(d_model)
|
175 |
+
self.norm3 = nn.LayerNorm(d_model)
|
176 |
+
self.dropout1 = nn.Dropout(dropout)
|
177 |
+
self.dropout2 = nn.Dropout(dropout)
|
178 |
+
self.dropout3 = nn.Dropout(dropout)
|
179 |
+
|
180 |
+
self.activation = _get_activation_fn(activation)
|
181 |
+
|
182 |
+
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
183 |
+
return tensor if pos is None else tensor + pos
|
184 |
+
|
185 |
+
def forward(self, tgt, memory,
|
186 |
+
tgt_mask: Optional[Tensor] = None,
|
187 |
+
memory_mask: Optional[Tensor] = None,
|
188 |
+
tgt_key_padding_mask: Optional[Tensor] = None,
|
189 |
+
memory_key_padding_mask: Optional[Tensor] = None,
|
190 |
+
pos: Optional[Tensor] = None,
|
191 |
+
query_pos: Optional[Tensor] = None):
|
192 |
+
tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
|
193 |
+
key=self.with_pos_embed(memory, pos),
|
194 |
+
value=memory, attn_mask=memory_mask,
|
195 |
+
key_padding_mask=memory_key_padding_mask)[0]
|
196 |
+
tgt = tgt + self.dropout2(tgt2)
|
197 |
+
tgt = self.norm2(tgt)
|
198 |
+
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
199 |
+
tgt = tgt + self.dropout3(tgt2)
|
200 |
+
tgt = self.norm3(tgt)
|
201 |
+
return tgt
|
202 |
+
|
203 |
+
|
204 |
+
def _get_clones(module, N):
|
205 |
+
return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
|
206 |
+
|
207 |
+
|
208 |
+
def build_transformer(args):
|
209 |
+
return Transformer(
|
210 |
+
d_model=args.hidden_dim,
|
211 |
+
dropout=args.dropout,
|
212 |
+
nhead=args.nheads,
|
213 |
+
dim_feedforward=args.dim_feedforward,
|
214 |
+
num_encoder_layers=args.enc_layers,
|
215 |
+
num_decoder_layers=args.dec_layers,
|
216 |
+
return_intermediate_dec=True,
|
217 |
+
)
|
218 |
+
|
219 |
+
|
220 |
+
def _get_activation_fn(activation):
|
221 |
+
"""Return an activation function given a string"""
|
222 |
+
if activation == "relu":
|
223 |
+
return F.relu
|
224 |
+
if activation == "gelu":
|
225 |
+
return F.gelu
|
226 |
+
if activation == "glu":
|
227 |
+
return F.glu
|
228 |
+
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
third_party/COTR/COTR/options/options.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
|
6 |
+
|
7 |
+
from COTR.options.options_utils import str2bool
|
8 |
+
from COTR.options import options_utils
|
9 |
+
from COTR.global_configs import general_config, dataset_config
|
10 |
+
from COTR.utils import debug_utils
|
11 |
+
|
12 |
+
|
13 |
+
def set_general_arguments(parser):
|
14 |
+
general_arg = parser.add_argument_group('General')
|
15 |
+
general_arg.add_argument('--confirm', type=str2bool,
|
16 |
+
default=True, help='promote confirmation for user')
|
17 |
+
general_arg.add_argument('--use_cuda', type=str2bool,
|
18 |
+
default=True, help='use cuda')
|
19 |
+
general_arg.add_argument('--use_cc', type=str2bool,
|
20 |
+
default=False, help='use computecanada')
|
21 |
+
|
22 |
+
|
23 |
+
def set_dataset_arguments(parser):
|
24 |
+
data_arg = parser.add_argument_group('Data')
|
25 |
+
data_arg.add_argument('--dataset_name', type=str, default='megadepth', help='dataset name')
|
26 |
+
data_arg.add_argument('--shuffle_data', type=str2bool, default=True, help='use sequence dataset or shuffled dataset')
|
27 |
+
data_arg.add_argument('--use_ram', type=str2bool, default=False, help='load image/depth/pcd to ram')
|
28 |
+
data_arg.add_argument('--info_level', choices=['rgb', 'rgbd'], type=str, default='rgbd', help='the information level of dataset')
|
29 |
+
data_arg.add_argument('--scene_file', type=str, default=None, required=False, help='what scene/seq want to use')
|
30 |
+
data_arg.add_argument('--workers', type=int, default=0, help='worker for loading data')
|
31 |
+
data_arg.add_argument('--crop_cam', choices=['no_crop', 'crop_center', 'crop_center_and_resize'], type=str, default='crop_center_and_resize', help='crop the center of image to avoid changing aspect ratio, resize to make the operations batch-able.')
|
32 |
+
|
33 |
+
|
34 |
+
def set_nn_arguments(parser):
|
35 |
+
nn_arg = parser.add_argument_group('Nearest neighbors')
|
36 |
+
nn_arg.add_argument('--nn_method', choices=['netvlad', 'overlapping'], type=str, default='overlapping', help='how to select nearest neighbors')
|
37 |
+
nn_arg.add_argument('--pool_size', type=int, default=20, help='a pool of sorted nn candidates')
|
38 |
+
nn_arg.add_argument('--k_size', type=int, default=1, help='select the nn randomly from pool')
|
39 |
+
|
40 |
+
|
41 |
+
def set_COTR_arguments(parser):
|
42 |
+
cotr_arg = parser.add_argument_group('COTR model')
|
43 |
+
cotr_arg.add_argument('--backbone', type=str, default='resnet50')
|
44 |
+
cotr_arg.add_argument('--hidden_dim', type=int, default=256)
|
45 |
+
cotr_arg.add_argument('--dilation', type=str2bool, default=False)
|
46 |
+
cotr_arg.add_argument('--dropout', type=float, default=0.1)
|
47 |
+
cotr_arg.add_argument('--nheads', type=int, default=8)
|
48 |
+
cotr_arg.add_argument('--layer', type=str, default='layer3', help='which layer from resnet')
|
49 |
+
cotr_arg.add_argument('--enc_layers', type=int, default=6)
|
50 |
+
cotr_arg.add_argument('--dec_layers', type=int, default=6)
|
51 |
+
cotr_arg.add_argument('--position_embedding', type=str, default='lin_sine', help='sine wave type')
|
52 |
+
|
third_party/COTR/COTR/options/options_utils.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''utils for argparse
|
2 |
+
'''
|
3 |
+
|
4 |
+
import sys
|
5 |
+
import os
|
6 |
+
from os import path
|
7 |
+
import time
|
8 |
+
import json
|
9 |
+
|
10 |
+
from COTR.utils import utils, debug_utils
|
11 |
+
from COTR.global_configs import general_config, dataset_config
|
12 |
+
|
13 |
+
|
14 |
+
def str2bool(v: str) -> bool:
|
15 |
+
return v.lower() in ('true', '1', 'yes', 'y', 't')
|
16 |
+
|
17 |
+
|
18 |
+
def get_compact_naming_cotr(opt) -> str:
|
19 |
+
base_str = 'model:cotr_{0}_{1}_{2}_dset:{3}_bs:{4}_pe:{5}_lrbackbone:{6}'
|
20 |
+
result = base_str.format(opt.backbone,
|
21 |
+
opt.layer,
|
22 |
+
opt.dim_feedforward,
|
23 |
+
opt.dataset_name,
|
24 |
+
opt.batch_size,
|
25 |
+
opt.position_embedding,
|
26 |
+
opt.lr_backbone,
|
27 |
+
)
|
28 |
+
if opt.suffix:
|
29 |
+
result = result + '_suffix:{0}'.format(opt.suffix)
|
30 |
+
return result
|
31 |
+
|
32 |
+
|
33 |
+
def print_opt(opt):
|
34 |
+
content_list = []
|
35 |
+
args = list(vars(opt))
|
36 |
+
args.sort()
|
37 |
+
for arg in args:
|
38 |
+
content_list += [arg.rjust(25, ' ') + ' ' + str(getattr(opt, arg))]
|
39 |
+
utils.print_notification(content_list, 'OPTIONS')
|
40 |
+
|
41 |
+
|
42 |
+
def confirm_opt(opt):
|
43 |
+
print_opt(opt)
|
44 |
+
if opt.use_cc == False:
|
45 |
+
if not utils.confirm():
|
46 |
+
exit(1)
|
47 |
+
|
48 |
+
|
49 |
+
def opt_to_string(opt) -> str:
|
50 |
+
string = '\n\n'
|
51 |
+
string += 'python ' + ' '.join(sys.argv)
|
52 |
+
string += '\n\n'
|
53 |
+
# string += '---------------------- CONFIG ----------------------\n'
|
54 |
+
args = list(vars(opt))
|
55 |
+
args.sort()
|
56 |
+
for arg in args:
|
57 |
+
string += arg.rjust(25, ' ') + ' ' + str(getattr(opt, arg)) + '\n\n'
|
58 |
+
# string += '----------------------------------------------------\n'
|
59 |
+
return string
|
60 |
+
|
61 |
+
|
62 |
+
def save_opt(opt):
|
63 |
+
'''save options to a json file
|
64 |
+
'''
|
65 |
+
if not os.path.exists(opt.out):
|
66 |
+
os.makedirs(opt.out)
|
67 |
+
json_path = os.path.join(opt.out, 'params.json')
|
68 |
+
if 'debug' not in opt.suffix and path.isfile(json_path):
|
69 |
+
assert opt.resume, 'You are trying to modify a model without resuming: {0}'.format(opt.out)
|
70 |
+
old_dict = json.load(open(json_path))
|
71 |
+
new_dict = vars(opt)
|
72 |
+
# assert old_dict.keys() == new_dict.keys(), 'New configuration keys is different from old one.\nold: {0}\nnew: {1}'.format(old_dict.keys(), new_dict.keys())
|
73 |
+
if new_dict != old_dict:
|
74 |
+
exception_keys = ['command']
|
75 |
+
for key in set(old_dict.keys()).union(set(new_dict.keys())):
|
76 |
+
if key not in exception_keys:
|
77 |
+
old_val = old_dict[key] if key in old_dict else 'not exists(old)'
|
78 |
+
new_val = new_dict[key] if key in old_dict else 'not exists(new)'
|
79 |
+
if old_val != new_val:
|
80 |
+
print('key: {0}, old_val: {1}, new_val: {2}'.format(key, old_val, new_val))
|
81 |
+
if opt.use_cc == False:
|
82 |
+
if not utils.confirm('Please manually confirm'):
|
83 |
+
exit(1)
|
84 |
+
with open(json_path, 'w') as fp:
|
85 |
+
json.dump(vars(opt), fp, indent=0, sort_keys=True)
|
86 |
+
|
87 |
+
|
88 |
+
def build_scenes_name_list_from_opt(opt):
|
89 |
+
if hasattr(opt, 'scene_file') and opt.scene_file is not None:
|
90 |
+
assert os.path.isfile(opt.scene_file), opt.scene_file
|
91 |
+
with open(opt.scene_file, 'r') as f:
|
92 |
+
scenes_list = json.load(f)
|
93 |
+
else:
|
94 |
+
scenes_list = [{'scene': opt.scene, 'seq': opt.seq}]
|
95 |
+
if 'megadepth' in opt.dataset_name:
|
96 |
+
assert opt.info_level in ['rgb', 'rgbd']
|
97 |
+
scenes_name_list = []
|
98 |
+
if opt.info_level == 'rgb':
|
99 |
+
dir_list = ['scene_dir', 'image_dir']
|
100 |
+
elif opt.info_level == 'rgbd':
|
101 |
+
dir_list = ['scene_dir', 'image_dir', 'depth_dir']
|
102 |
+
dir_list = {dir_name: dataset_config[opt.dataset_name][dir_name] for dir_name in dir_list}
|
103 |
+
for item in scenes_list:
|
104 |
+
cur_scene = {key: val.format(item['scene'], item['seq']) for key, val in dir_list.items()}
|
105 |
+
scenes_name_list.append(cur_scene)
|
106 |
+
else:
|
107 |
+
raise NotImplementedError()
|
108 |
+
return scenes_name_list
|
third_party/COTR/COTR/projector/pcd_projector.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
a point cloud projector based on np
|
3 |
+
'''
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from COTR.utils import debug_utils, utils
|
8 |
+
|
9 |
+
|
10 |
+
def render_point_cloud_at_capture(point_cloud, capture, render_type='rgb', return_pcd=False):
|
11 |
+
assert render_type in ['rgb', 'bw', 'depth']
|
12 |
+
if render_type == 'rgb':
|
13 |
+
assert point_cloud.shape[1] == 6
|
14 |
+
else:
|
15 |
+
point_cloud = point_cloud[:, :3]
|
16 |
+
assert point_cloud.shape[1] == 3
|
17 |
+
if render_type in ['bw', 'rgb']:
|
18 |
+
keep_z = False
|
19 |
+
else:
|
20 |
+
keep_z = True
|
21 |
+
|
22 |
+
pcd_2d = PointCloudProjector.pcd_3d_to_pcd_2d_np(point_cloud,
|
23 |
+
capture.intrinsic_mat,
|
24 |
+
capture.extrinsic_mat,
|
25 |
+
capture.size,
|
26 |
+
keep_z=True,
|
27 |
+
crop=True,
|
28 |
+
filter_neg=True,
|
29 |
+
norm_coord=False,
|
30 |
+
return_index=False)
|
31 |
+
reproj = PointCloudProjector.pcd_2d_to_img_2d_np(pcd_2d,
|
32 |
+
capture.size,
|
33 |
+
has_z=True,
|
34 |
+
keep_z=keep_z)
|
35 |
+
if return_pcd:
|
36 |
+
return reproj, pcd_2d
|
37 |
+
else:
|
38 |
+
return reproj
|
39 |
+
|
40 |
+
|
41 |
+
def optical_flow_from_a_to_b(cap_a, cap_b):
|
42 |
+
cap_a_intrinsic = cap_a.pinhole_cam.intrinsic_mat
|
43 |
+
cap_a_img_size = cap_a.pinhole_cam.shape[:2]
|
44 |
+
_h, _w = cap_b.pinhole_cam.shape[:2]
|
45 |
+
x, y = np.meshgrid(
|
46 |
+
np.linspace(0, _w - 1, num=_w),
|
47 |
+
np.linspace(0, _h - 1, num=_h),
|
48 |
+
)
|
49 |
+
coord_map = np.concatenate([np.expand_dims(x, 2), np.expand_dims(y, 2)], axis=2)
|
50 |
+
pcd_from_cap_b = cap_b.get_point_cloud_world_from_depth(coord_map)
|
51 |
+
# pcd_from_cap_b = cap_b.point_cloud_world_w_feat(['pos', 'coord'])
|
52 |
+
optical_flow = PointCloudProjector.pcd_2d_to_img_2d_np(PointCloudProjector.pcd_3d_to_pcd_2d_np(pcd_from_cap_b, cap_a_intrinsic, cap_a.cam_pose.world_to_camera[0:3, :], cap_a_img_size, keep_z=True, crop=True, filter_neg=True, norm_coord=False), cap_a_img_size, has_z=True, keep_z=False)
|
53 |
+
return optical_flow
|
54 |
+
|
55 |
+
|
56 |
+
class PointCloudProjector():
|
57 |
+
def __init__(self):
|
58 |
+
pass
|
59 |
+
|
60 |
+
@staticmethod
|
61 |
+
def pcd_2d_to_pcd_3d_np(pcd, depth, intrinsic, motion=None, return_index=False):
|
62 |
+
assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd))
|
63 |
+
assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
|
64 |
+
assert len(pcd.shape) == 2 and pcd.shape[1] >= 2
|
65 |
+
assert len(depth.shape) == 2 and depth.shape[1] == 1
|
66 |
+
assert intrinsic.shape == (3, 3)
|
67 |
+
if motion is not None:
|
68 |
+
assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion))
|
69 |
+
assert motion.shape == (4, 4)
|
70 |
+
# exec(debug_utils.embed_breakpoint())
|
71 |
+
x, y, z = pcd[:, 0], pcd[:, 1], depth[:, 0]
|
72 |
+
append_ones = np.ones_like(x)
|
73 |
+
xyz = np.stack([x, y, append_ones], axis=1) # shape: [num_points, 3]
|
74 |
+
inv_intrinsic_mat = np.linalg.inv(intrinsic)
|
75 |
+
xyz = np.matmul(inv_intrinsic_mat, xyz.T).T * z[..., None]
|
76 |
+
valid_mask_1 = np.where(xyz[:, 2] > 0)
|
77 |
+
xyz = xyz[valid_mask_1]
|
78 |
+
|
79 |
+
if motion is not None:
|
80 |
+
append_ones = np.ones_like(xyz[:, 0:1])
|
81 |
+
xyzw = np.concatenate([xyz, append_ones], axis=1)
|
82 |
+
xyzw = np.matmul(motion, xyzw.T).T
|
83 |
+
valid_mask_2 = np.where(xyzw[:, 3] != 0)
|
84 |
+
xyzw = xyzw[valid_mask_2]
|
85 |
+
xyzw /= xyzw[:, 3:4]
|
86 |
+
xyz = xyzw[:, 0:3]
|
87 |
+
|
88 |
+
if pcd.shape[1] > 2:
|
89 |
+
features = pcd[:, 2:]
|
90 |
+
try:
|
91 |
+
features = features[valid_mask_1][valid_mask_2]
|
92 |
+
except UnboundLocalError:
|
93 |
+
features = features[valid_mask_1]
|
94 |
+
assert xyz.shape[0] == features.shape[0]
|
95 |
+
xyz = np.concatenate([xyz, features], axis=1)
|
96 |
+
if return_index:
|
97 |
+
points_index = np.arange(pcd.shape[0])[valid_mask_1][valid_mask_2]
|
98 |
+
return xyz, points_index
|
99 |
+
return xyz
|
100 |
+
|
101 |
+
@staticmethod
|
102 |
+
def img_2d_to_pcd_3d_np(depth, intrinsic, img=None, motion=None):
|
103 |
+
'''
|
104 |
+
the function signature is not fully correct, because img is an optional
|
105 |
+
if motion is None, the output pcd is in camera space
|
106 |
+
if motion is camera_to_world, the out pcd is in world space.
|
107 |
+
here the output is pure np array
|
108 |
+
'''
|
109 |
+
|
110 |
+
assert isinstance(depth, np.ndarray), 'cannot process data type: {0}'.format(type(depth))
|
111 |
+
assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
|
112 |
+
assert len(depth.shape) == 2
|
113 |
+
assert intrinsic.shape == (3, 3)
|
114 |
+
if img is not None:
|
115 |
+
assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img))
|
116 |
+
assert len(img.shape) == 3
|
117 |
+
assert img.shape[:2] == depth.shape[:2], 'feature should have the same resolution as the depth'
|
118 |
+
if motion is not None:
|
119 |
+
assert isinstance(motion, np.ndarray), 'cannot process data type: {0}'.format(type(motion))
|
120 |
+
assert motion.shape == (4, 4)
|
121 |
+
|
122 |
+
pcd_image_space = PointCloudProjector.img_2d_to_pcd_2d_np(depth[..., None], norm_coord=False)
|
123 |
+
valid_mask_1 = np.where(pcd_image_space[:, 2] > 0)
|
124 |
+
pcd_image_space = pcd_image_space[valid_mask_1]
|
125 |
+
xy = pcd_image_space[:, :2]
|
126 |
+
z = pcd_image_space[:, 2:3]
|
127 |
+
if img is not None:
|
128 |
+
_c = img.shape[-1]
|
129 |
+
feat = img.reshape(-1, _c)
|
130 |
+
feat = feat[valid_mask_1]
|
131 |
+
xy = np.concatenate([xy, feat], axis=1)
|
132 |
+
pcd_3d = PointCloudProjector.pcd_2d_to_pcd_3d_np(xy, z, intrinsic, motion=motion)
|
133 |
+
return pcd_3d
|
134 |
+
|
135 |
+
@staticmethod
|
136 |
+
def pcd_3d_to_pcd_2d_np(pcd, intrinsic, extrinsic, size, keep_z: bool, crop: bool = True, filter_neg: bool = True, norm_coord: bool = True, return_index: bool = False):
|
137 |
+
assert isinstance(pcd, np.ndarray), 'cannot process data type: {0}'.format(type(pcd))
|
138 |
+
assert isinstance(intrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(intrinsic))
|
139 |
+
assert isinstance(extrinsic, np.ndarray), 'cannot process data type: {0}'.format(type(extrinsic))
|
140 |
+
assert len(pcd.shape) == 2 and pcd.shape[1] >= 3, 'seems the input pcd is not a valid 3d point cloud: {0}'.format(pcd.shape)
|
141 |
+
|
142 |
+
xyzw = np.concatenate([pcd[:, 0:3], np.ones_like(pcd[:, 0:1])], axis=1)
|
143 |
+
mvp_mat = np.matmul(intrinsic, extrinsic)
|
144 |
+
camera_points = np.matmul(mvp_mat, xyzw.T).T
|
145 |
+
if filter_neg:
|
146 |
+
valid_mask_1 = camera_points[:, 2] > 0.0
|
147 |
+
else:
|
148 |
+
valid_mask_1 = np.ones_like(camera_points[:, 2], dtype=bool)
|
149 |
+
camera_points = camera_points[valid_mask_1]
|
150 |
+
image_points = camera_points / camera_points[:, 2:3]
|
151 |
+
image_points = image_points[:, :2]
|
152 |
+
if crop:
|
153 |
+
valid_mask_2 = (image_points[:, 0] >= 0) * (image_points[:, 0] < size[1] - 1) * (image_points[:, 1] >= 0) * (image_points[:, 1] < size[0] - 1)
|
154 |
+
else:
|
155 |
+
valid_mask_2 = np.ones_like(image_points[:, 0], dtype=bool)
|
156 |
+
if norm_coord:
|
157 |
+
image_points = ((image_points / size[::-1]) * 2) - 1
|
158 |
+
|
159 |
+
if keep_z:
|
160 |
+
image_points = np.concatenate([image_points[valid_mask_2], camera_points[valid_mask_2][:, 2:3], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1)
|
161 |
+
else:
|
162 |
+
image_points = np.concatenate([image_points[valid_mask_2], pcd[valid_mask_1][:, 3:][valid_mask_2]], axis=1)
|
163 |
+
# if filter_neg and crop:
|
164 |
+
# exec(debug_utils.embed_breakpoint('pcd_3d_to_pcd_2d_np'))
|
165 |
+
if return_index:
|
166 |
+
points_index = np.arange(pcd.shape[0])[valid_mask_1][valid_mask_2]
|
167 |
+
return image_points, points_index
|
168 |
+
return image_points
|
169 |
+
|
170 |
+
@staticmethod
|
171 |
+
def pcd_2d_to_img_2d_np(pcd, size, has_z=False, keep_z=False):
|
172 |
+
assert len(pcd.shape) == 2 and pcd.shape[-1] >= 2, 'seems the input pcd is not a valid point cloud: {0}'.format(pcd.shape)
|
173 |
+
# assert 0, 'pass Z values in'
|
174 |
+
if has_z:
|
175 |
+
pcd = pcd[pcd[:, 2].argsort()[::-1]]
|
176 |
+
if not keep_z:
|
177 |
+
pcd = np.delete(pcd, [2], axis=1)
|
178 |
+
index_list = np.round(pcd[:, 0:2]).astype(np.int32)
|
179 |
+
index_list[:, 0] = np.clip(index_list[:, 0], 0, size[1] - 1)
|
180 |
+
index_list[:, 1] = np.clip(index_list[:, 1], 0, size[0] - 1)
|
181 |
+
_h, _w, _c = *size, pcd.shape[-1] - 2
|
182 |
+
if _c == 0:
|
183 |
+
canvas = np.zeros((_h, _w, 1))
|
184 |
+
canvas[index_list[:, 1], index_list[:, 0]] = 1.0
|
185 |
+
else:
|
186 |
+
canvas = np.zeros((_h, _w, _c))
|
187 |
+
canvas[index_list[:, 1], index_list[:, 0]] = pcd[:, 2:]
|
188 |
+
|
189 |
+
return canvas
|
190 |
+
|
191 |
+
@staticmethod
|
192 |
+
def img_2d_to_pcd_2d_np(img, norm_coord=True):
|
193 |
+
assert isinstance(img, np.ndarray), 'cannot process data type: {0}'.format(type(img))
|
194 |
+
assert len(img.shape) == 3
|
195 |
+
|
196 |
+
_h, _w, _c = img.shape
|
197 |
+
if norm_coord:
|
198 |
+
x, y = np.meshgrid(
|
199 |
+
np.linspace(-1, 1, num=_w),
|
200 |
+
np.linspace(-1, 1, num=_h),
|
201 |
+
)
|
202 |
+
else:
|
203 |
+
x, y = np.meshgrid(
|
204 |
+
np.linspace(0, _w - 1, num=_w),
|
205 |
+
np.linspace(0, _h - 1, num=_h),
|
206 |
+
)
|
207 |
+
x, y = x.reshape(-1, 1), y.reshape(-1, 1)
|
208 |
+
feat = img.reshape(-1, _c)
|
209 |
+
pcd_2d = np.concatenate([x, y, feat], axis=1)
|
210 |
+
return pcd_2d
|
third_party/COTR/COTR/sfm_scenes/knn_search.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Given one capture in a scene, search for its KNN captures
|
3 |
+
'''
|
4 |
+
|
5 |
+
import os
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
from COTR.utils import debug_utils
|
10 |
+
from COTR.utils.constants import VALID_NN_OVERLAPPING_THRESH
|
11 |
+
|
12 |
+
|
13 |
+
class ReprojRatioKnnSearch():
|
14 |
+
def __init__(self, scene):
|
15 |
+
self.scene = scene
|
16 |
+
self.distance_mat = None
|
17 |
+
self.nn_index = None
|
18 |
+
self._read_dist_mat()
|
19 |
+
self._build_nn_index()
|
20 |
+
|
21 |
+
def _read_dist_mat(self):
|
22 |
+
dist_mat_path = os.path.join(os.path.dirname(os.path.dirname(self.scene.captures[0].depth_path)), 'dist_mat/dist_mat.npy')
|
23 |
+
self.distance_mat = np.load(dist_mat_path)
|
24 |
+
|
25 |
+
def _build_nn_index(self):
|
26 |
+
# argsort is in ascending order, so we take negative
|
27 |
+
self.nn_index = (-1 * self.distance_mat).argsort(axis=1)
|
28 |
+
|
29 |
+
def get_knn(self, query, k, db_mask=None):
|
30 |
+
query_index = self.scene.img_path_to_index_dict[query.img_path]
|
31 |
+
if db_mask is not None:
|
32 |
+
query_mask = np.setdiff1d(np.arange(self.distance_mat[query_index].shape[0]), db_mask)
|
33 |
+
num_pos = (self.distance_mat[query_index] > VALID_NN_OVERLAPPING_THRESH).sum() if db_mask is None else (self.distance_mat[query_index][db_mask] > VALID_NN_OVERLAPPING_THRESH).sum()
|
34 |
+
# we have enough valid NN or not
|
35 |
+
if num_pos > k:
|
36 |
+
if db_mask is None:
|
37 |
+
ind = self.nn_index[query_index][:k + 1]
|
38 |
+
else:
|
39 |
+
temp_dist = self.distance_mat[query_index].copy()
|
40 |
+
temp_dist[query_mask] = -1
|
41 |
+
ind = (-1 * temp_dist).argsort(axis=0)[:k + 1]
|
42 |
+
# remove self
|
43 |
+
if query_index in ind:
|
44 |
+
ind = np.delete(ind, np.argwhere(ind == query_index))
|
45 |
+
else:
|
46 |
+
ind = ind[:k]
|
47 |
+
assert ind.shape[0] <= k, ind.shape[0] > 0
|
48 |
+
else:
|
49 |
+
k = num_pos
|
50 |
+
if db_mask is None:
|
51 |
+
ind = self.nn_index[query_index][:max(k, 1)]
|
52 |
+
else:
|
53 |
+
temp_dist = self.distance_mat[query_index].copy()
|
54 |
+
temp_dist[query_mask] = -1
|
55 |
+
ind = (-1 * temp_dist).argsort(axis=0)[:max(k, 1)]
|
56 |
+
return self.scene.get_captures_given_index_list(ind)
|
third_party/COTR/COTR/sfm_scenes/sfm_scenes.py
ADDED
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Scene reconstructed from SFM, mainly colmap
|
3 |
+
'''
|
4 |
+
import os
|
5 |
+
import copy
|
6 |
+
import math
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
from numpy.linalg import inv
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from COTR.transformations import transformations
|
13 |
+
from COTR.transformations.transform_basics import Translation, Rotation
|
14 |
+
from COTR.cameras.camera_pose import CameraPose
|
15 |
+
from COTR.utils import debug_utils
|
16 |
+
|
17 |
+
|
18 |
+
class SfmScene():
|
19 |
+
def __init__(self, captures, point_cloud=None):
|
20 |
+
self.captures = captures
|
21 |
+
if isinstance(point_cloud, tuple):
|
22 |
+
self.point_cloud = point_cloud[0]
|
23 |
+
self.point_meta = point_cloud[1]
|
24 |
+
else:
|
25 |
+
self.point_cloud = point_cloud
|
26 |
+
self.img_path_to_index_dict = {}
|
27 |
+
self.img_id_to_index_dict = {}
|
28 |
+
self.fname_to_index_dict = {}
|
29 |
+
self._build_img_X_to_index_dict()
|
30 |
+
|
31 |
+
def __str__(self):
|
32 |
+
string = 'Scene contains {0} captures'.format(len(self.captures))
|
33 |
+
return string
|
34 |
+
|
35 |
+
def __getitem__(self, x):
|
36 |
+
if isinstance(x, str):
|
37 |
+
try:
|
38 |
+
return self.captures[self.img_path_to_index_dict[x]]
|
39 |
+
except:
|
40 |
+
return self.captures[self.fname_to_index_dict[x]]
|
41 |
+
else:
|
42 |
+
return self.captures[x]
|
43 |
+
|
44 |
+
def _build_img_X_to_index_dict(self):
|
45 |
+
assert self.captures is not None, 'There is no captures'
|
46 |
+
for i, cap in enumerate(self.captures):
|
47 |
+
assert cap.img_path not in self.img_path_to_index_dict, 'Image already exists'
|
48 |
+
self.img_path_to_index_dict[cap.img_path] = i
|
49 |
+
assert os.path.basename(cap.img_path) not in self.fname_to_index_dict, 'Image already exists'
|
50 |
+
self.fname_to_index_dict[os.path.basename(cap.img_path)] = i
|
51 |
+
if hasattr(cap, 'image_id'):
|
52 |
+
self.img_id_to_index_dict[cap.image_id] = i
|
53 |
+
|
54 |
+
def get_captures_given_index_list(self, index_list):
|
55 |
+
captures_list = []
|
56 |
+
for i in index_list:
|
57 |
+
captures_list.append(self.captures[i])
|
58 |
+
return captures_list
|
59 |
+
|
60 |
+
def get_covisible_caps(self, cap):
|
61 |
+
assert cap.img_path in self.img_path_to_index_dict
|
62 |
+
covis_img_id = set()
|
63 |
+
point_ids = cap.point3d_id
|
64 |
+
for i in point_ids:
|
65 |
+
covis_img_id = covis_img_id.union(set(self.point_meta[i].image_ids))
|
66 |
+
covis_caps = []
|
67 |
+
for i in covis_img_id:
|
68 |
+
if i in self.img_id_to_index_dict:
|
69 |
+
covis_caps.append(self.captures[self.img_id_to_index_dict[i]])
|
70 |
+
else:
|
71 |
+
pass
|
72 |
+
return covis_caps
|
73 |
+
|
74 |
+
def read_data_to_ram(self, data_list):
|
75 |
+
print('warning: you are going to use a lot of RAM.')
|
76 |
+
sum_bytes = 0.0
|
77 |
+
pbar = tqdm(self.captures, desc='reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0)))
|
78 |
+
for cap in pbar:
|
79 |
+
if 'image' in data_list:
|
80 |
+
sum_bytes += cap.read_image_to_ram()
|
81 |
+
if 'depth' in data_list:
|
82 |
+
sum_bytes += cap.read_depth_to_ram()
|
83 |
+
if 'pcd' in data_list:
|
84 |
+
sum_bytes += cap.read_pcd_to_ram()
|
85 |
+
pbar.set_description('reading data, memory usage {0:.2f} MB'.format(sum_bytes / (1024.0 * 1024.0)))
|
86 |
+
print('----- total memory usage for images: {0} MB-----'.format(sum_bytes / (1024.0 * 1024.0)))
|
87 |
+
|
third_party/COTR/COTR/trainers/base_trainer.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import abc
|
4 |
+
import time
|
5 |
+
|
6 |
+
import tqdm
|
7 |
+
import torch.nn as nn
|
8 |
+
import tensorboardX
|
9 |
+
|
10 |
+
from COTR.trainers import tensorboard_helper
|
11 |
+
from COTR.utils import utils
|
12 |
+
from COTR.options import options_utils
|
13 |
+
|
14 |
+
|
15 |
+
class BaseTrainer(abc.ABC):
|
16 |
+
'''base trainer class.
|
17 |
+
contains methods for training, validation, and writing output.
|
18 |
+
'''
|
19 |
+
|
20 |
+
def __init__(self, opt, model, optimizer, criterion,
|
21 |
+
train_loader, val_loader):
|
22 |
+
self.opt = opt
|
23 |
+
self.use_cuda = opt.use_cuda
|
24 |
+
self.model = model
|
25 |
+
self.optim = optimizer
|
26 |
+
self.criterion = criterion
|
27 |
+
self.train_loader = train_loader
|
28 |
+
self.val_loader = val_loader
|
29 |
+
self.out = opt.out
|
30 |
+
if not os.path.exists(opt.out):
|
31 |
+
os.makedirs(opt.out)
|
32 |
+
self.epoch = 0
|
33 |
+
self.iteration = 0
|
34 |
+
self.max_iter = opt.max_iter
|
35 |
+
self.valid_iter = opt.valid_iter
|
36 |
+
self.tb_pusher = tensorboard_helper.TensorboardPusher(opt)
|
37 |
+
self.push_opt_to_tb()
|
38 |
+
self.need_resume = opt.resume
|
39 |
+
if self.need_resume:
|
40 |
+
self.resume()
|
41 |
+
if self.opt.load_weights:
|
42 |
+
self.load_pretrained_weights()
|
43 |
+
|
44 |
+
def push_opt_to_tb(self):
|
45 |
+
opt_str = options_utils.opt_to_string(self.opt)
|
46 |
+
tb_datapack = tensorboard_helper.TensorboardDatapack()
|
47 |
+
tb_datapack.set_training(False)
|
48 |
+
tb_datapack.set_iteration(self.iteration)
|
49 |
+
tb_datapack.add_text({'options': opt_str})
|
50 |
+
self.tb_pusher.push_to_tensorboard(tb_datapack)
|
51 |
+
|
52 |
+
@abc.abstractmethod
|
53 |
+
def validate_batch(self, data_pack):
|
54 |
+
pass
|
55 |
+
|
56 |
+
@abc.abstractmethod
|
57 |
+
def validate(self):
|
58 |
+
pass
|
59 |
+
|
60 |
+
@abc.abstractmethod
|
61 |
+
def train_batch(self, data_pack):
|
62 |
+
'''train for one batch of data
|
63 |
+
'''
|
64 |
+
pass
|
65 |
+
|
66 |
+
def train_epoch(self):
|
67 |
+
'''train for one epoch
|
68 |
+
one epoch is iterating the whole training dataset once
|
69 |
+
'''
|
70 |
+
self.model.train()
|
71 |
+
for batch_idx, data_pack in tqdm.tqdm(enumerate(self.train_loader),
|
72 |
+
initial=self.iteration % len(
|
73 |
+
self.train_loader),
|
74 |
+
total=len(self.train_loader),
|
75 |
+
desc='Train epoch={0}'.format(
|
76 |
+
self.epoch),
|
77 |
+
ncols=80,
|
78 |
+
leave=True,
|
79 |
+
):
|
80 |
+
|
81 |
+
# iteration = batch_idx + self.epoch * len(self.train_loader)
|
82 |
+
# if self.iteration != 0 and (iteration - 1) != self.iteration:
|
83 |
+
# continue # for resuming
|
84 |
+
# self.iteration = iteration
|
85 |
+
# self.iteration += 1
|
86 |
+
if self.iteration % self.valid_iter == 0:
|
87 |
+
time.sleep(2) # Prevent possible deadlock during epoch transition
|
88 |
+
self.validate()
|
89 |
+
self.train_batch(data_pack)
|
90 |
+
|
91 |
+
if self.iteration >= self.max_iter:
|
92 |
+
break
|
93 |
+
self.iteration += 1
|
94 |
+
|
95 |
+
def train(self):
|
96 |
+
'''entrance of the whole training process
|
97 |
+
'''
|
98 |
+
max_epoch = int(math.ceil(1. * self.max_iter / len(self.train_loader)))
|
99 |
+
for epoch in tqdm.trange(self.epoch,
|
100 |
+
max_epoch,
|
101 |
+
desc='Train',
|
102 |
+
ncols=80):
|
103 |
+
self.epoch = epoch
|
104 |
+
time.sleep(2) # Prevent possible deadlock during epoch transition
|
105 |
+
self.train_epoch()
|
106 |
+
if self.iteration >= self.max_iter:
|
107 |
+
break
|
108 |
+
|
109 |
+
@abc.abstractmethod
|
110 |
+
def resume(self):
|
111 |
+
pass
|
third_party/COTR/COTR/trainers/cotr_trainer.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
import os.path as osp
|
4 |
+
import time
|
5 |
+
|
6 |
+
import tqdm
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import torchvision.utils as vutils
|
10 |
+
from PIL import Image, ImageDraw
|
11 |
+
|
12 |
+
|
13 |
+
from COTR.utils import utils, debug_utils, constants
|
14 |
+
from COTR.trainers import base_trainer, tensorboard_helper
|
15 |
+
from COTR.projector import pcd_projector
|
16 |
+
|
17 |
+
|
18 |
+
class COTRTrainer(base_trainer.BaseTrainer):
|
19 |
+
def __init__(self, opt, model, optimizer, criterion,
|
20 |
+
train_loader, val_loader):
|
21 |
+
super().__init__(opt, model, optimizer, criterion,
|
22 |
+
train_loader, val_loader)
|
23 |
+
|
24 |
+
def validate_batch(self, data_pack):
|
25 |
+
assert self.model.training is False
|
26 |
+
with torch.no_grad():
|
27 |
+
img = data_pack['image'].cuda()
|
28 |
+
query = data_pack['queries'].cuda()
|
29 |
+
target = data_pack['targets'].cuda()
|
30 |
+
self.optim.zero_grad()
|
31 |
+
pred = self.model(img, query)['pred_corrs']
|
32 |
+
loss = torch.nn.functional.mse_loss(pred, target)
|
33 |
+
if self.opt.cycle_consis and self.opt.bidirectional:
|
34 |
+
cycle = self.model(img, pred)['pred_corrs']
|
35 |
+
mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
|
36 |
+
if mask.sum() > 0:
|
37 |
+
cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
|
38 |
+
loss += cycle_loss
|
39 |
+
elif self.opt.cycle_consis and not self.opt.bidirectional:
|
40 |
+
img_reverse = torch.cat([img[..., constants.MAX_SIZE:], img[..., :constants.MAX_SIZE]], axis=-1)
|
41 |
+
query_reverse = pred.clone()
|
42 |
+
query_reverse[..., 0] = query_reverse[..., 0] - 0.5
|
43 |
+
cycle = self.model(img_reverse, query_reverse)['pred_corrs']
|
44 |
+
cycle[..., 0] = cycle[..., 0] - 0.5
|
45 |
+
mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
|
46 |
+
if mask.sum() > 0:
|
47 |
+
cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
|
48 |
+
loss += cycle_loss
|
49 |
+
loss_data = loss.data.item()
|
50 |
+
if np.isnan(loss_data):
|
51 |
+
print('loss is nan while validating')
|
52 |
+
return loss_data, pred
|
53 |
+
|
54 |
+
def validate(self):
|
55 |
+
'''validate for whole validation dataset
|
56 |
+
'''
|
57 |
+
training = self.model.training
|
58 |
+
self.model.eval()
|
59 |
+
val_loss_list = []
|
60 |
+
for batch_idx, data_pack in tqdm.tqdm(
|
61 |
+
enumerate(self.val_loader), total=len(self.val_loader),
|
62 |
+
desc='Valid iteration=%d' % self.iteration, ncols=80,
|
63 |
+
leave=False):
|
64 |
+
loss_data, pred = self.validate_batch(data_pack)
|
65 |
+
val_loss_list.append(loss_data)
|
66 |
+
mean_loss = np.array(val_loss_list).mean()
|
67 |
+
validation_data = {'val_loss': mean_loss,
|
68 |
+
'pred': pred,
|
69 |
+
}
|
70 |
+
self.push_validation_data(data_pack, validation_data)
|
71 |
+
self.save_model()
|
72 |
+
if training:
|
73 |
+
self.model.train()
|
74 |
+
|
75 |
+
def save_model(self):
|
76 |
+
torch.save({
|
77 |
+
'epoch': self.epoch,
|
78 |
+
'iteration': self.iteration,
|
79 |
+
'optim_state_dict': self.optim.state_dict(),
|
80 |
+
'model_state_dict': self.model.state_dict(),
|
81 |
+
}, osp.join(self.out, 'checkpoint.pth.tar'))
|
82 |
+
if self.iteration % (10 * self.valid_iter) == 0:
|
83 |
+
torch.save({
|
84 |
+
'epoch': self.epoch,
|
85 |
+
'iteration': self.iteration,
|
86 |
+
'optim_state_dict': self.optim.state_dict(),
|
87 |
+
'model_state_dict': self.model.state_dict(),
|
88 |
+
}, osp.join(self.out, f'{self.iteration}_checkpoint.pth.tar'))
|
89 |
+
|
90 |
+
def draw_corrs(self, imgs, corrs, col=(255, 0, 0)):
|
91 |
+
imgs = utils.torch_img_to_np_img(imgs)
|
92 |
+
out = []
|
93 |
+
for img, corr in zip(imgs, corrs):
|
94 |
+
img = np.interp(img, [img.min(), img.max()], [0, 255]).astype(np.uint8)
|
95 |
+
img = Image.fromarray(img)
|
96 |
+
draw = ImageDraw.Draw(img)
|
97 |
+
corr *= np.array([constants.MAX_SIZE * 2, constants.MAX_SIZE, constants.MAX_SIZE * 2, constants.MAX_SIZE])
|
98 |
+
for c in corr:
|
99 |
+
draw.line(c, fill=col)
|
100 |
+
out.append(np.array(img))
|
101 |
+
out = np.array(out) / 255.0
|
102 |
+
return utils.np_img_to_torch_img(out)
|
103 |
+
|
104 |
+
def push_validation_data(self, data_pack, validation_data):
|
105 |
+
val_loss = validation_data['val_loss']
|
106 |
+
pred_corrs = np.concatenate([data_pack['queries'].numpy(), validation_data['pred'].cpu().numpy()], axis=-1)
|
107 |
+
pred_corrs = self.draw_corrs(data_pack['image'], pred_corrs)
|
108 |
+
gt_corrs = np.concatenate([data_pack['queries'].numpy(), data_pack['targets'].cpu().numpy()], axis=-1)
|
109 |
+
gt_corrs = self.draw_corrs(data_pack['image'], gt_corrs, (0, 255, 0))
|
110 |
+
|
111 |
+
gt_img = vutils.make_grid(gt_corrs, normalize=True, scale_each=True)
|
112 |
+
pred_img = vutils.make_grid(pred_corrs, normalize=True, scale_each=True)
|
113 |
+
tb_datapack = tensorboard_helper.TensorboardDatapack()
|
114 |
+
tb_datapack.set_training(False)
|
115 |
+
tb_datapack.set_iteration(self.iteration)
|
116 |
+
tb_datapack.add_scalar({'loss/val': val_loss})
|
117 |
+
tb_datapack.add_image({'image/gt_corrs': gt_img})
|
118 |
+
tb_datapack.add_image({'image/pred_corrs': pred_img})
|
119 |
+
self.tb_pusher.push_to_tensorboard(tb_datapack)
|
120 |
+
|
121 |
+
def train_batch(self, data_pack):
|
122 |
+
'''train for one batch of data
|
123 |
+
'''
|
124 |
+
img = data_pack['image'].cuda()
|
125 |
+
query = data_pack['queries'].cuda()
|
126 |
+
target = data_pack['targets'].cuda()
|
127 |
+
|
128 |
+
self.optim.zero_grad()
|
129 |
+
pred = self.model(img, query)['pred_corrs']
|
130 |
+
loss = torch.nn.functional.mse_loss(pred, target)
|
131 |
+
if self.opt.cycle_consis and self.opt.bidirectional:
|
132 |
+
cycle = self.model(img, pred)['pred_corrs']
|
133 |
+
mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
|
134 |
+
if mask.sum() > 0:
|
135 |
+
cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
|
136 |
+
loss += cycle_loss
|
137 |
+
elif self.opt.cycle_consis and not self.opt.bidirectional:
|
138 |
+
img_reverse = torch.cat([img[..., constants.MAX_SIZE:], img[..., :constants.MAX_SIZE]], axis=-1)
|
139 |
+
query_reverse = pred.clone()
|
140 |
+
query_reverse[..., 0] = query_reverse[..., 0] - 0.5
|
141 |
+
cycle = self.model(img_reverse, query_reverse)['pred_corrs']
|
142 |
+
cycle[..., 0] = cycle[..., 0] - 0.5
|
143 |
+
mask = torch.norm(cycle - query, dim=-1) < 10 / constants.MAX_SIZE
|
144 |
+
if mask.sum() > 0:
|
145 |
+
cycle_loss = torch.nn.functional.mse_loss(cycle[mask], query[mask])
|
146 |
+
loss += cycle_loss
|
147 |
+
loss_data = loss.data.item()
|
148 |
+
if np.isnan(loss_data):
|
149 |
+
print('loss is nan during training')
|
150 |
+
self.optim.zero_grad()
|
151 |
+
else:
|
152 |
+
loss.backward()
|
153 |
+
self.push_training_data(data_pack, pred, target, loss)
|
154 |
+
self.optim.step()
|
155 |
+
|
156 |
+
def push_training_data(self, data_pack, pred, target, loss):
|
157 |
+
tb_datapack = tensorboard_helper.TensorboardDatapack()
|
158 |
+
tb_datapack.set_training(True)
|
159 |
+
tb_datapack.set_iteration(self.iteration)
|
160 |
+
tb_datapack.add_histogram({'distribution/pred': pred})
|
161 |
+
tb_datapack.add_histogram({'distribution/target': target})
|
162 |
+
tb_datapack.add_scalar({'loss/train': loss})
|
163 |
+
self.tb_pusher.push_to_tensorboard(tb_datapack)
|
164 |
+
|
165 |
+
def resume(self):
|
166 |
+
'''resume training:
|
167 |
+
resume from the recorded epoch, iteration, and saved weights.
|
168 |
+
resume from the model with the same name.
|
169 |
+
|
170 |
+
Arguments:
|
171 |
+
opt {[type]} -- [description]
|
172 |
+
'''
|
173 |
+
if hasattr(self.opt, 'load_weights'):
|
174 |
+
assert self.opt.load_weights is None or self.opt.load_weights == False
|
175 |
+
# 1. load check point
|
176 |
+
checkpoint_path = os.path.join(self.opt.out, 'checkpoint.pth.tar')
|
177 |
+
if os.path.isfile(checkpoint_path):
|
178 |
+
checkpoint = torch.load(checkpoint_path)
|
179 |
+
else:
|
180 |
+
raise FileNotFoundError(
|
181 |
+
'model check point cannnot found: {0}'.format(checkpoint_path))
|
182 |
+
# 2. load data
|
183 |
+
self.epoch = checkpoint['epoch']
|
184 |
+
self.iteration = checkpoint['iteration']
|
185 |
+
self.load_pretrained_weights()
|
186 |
+
self.optim.load_state_dict(checkpoint['optim_state_dict'])
|
187 |
+
|
188 |
+
def load_pretrained_weights(self):
|
189 |
+
'''
|
190 |
+
load pretrained weights from another model
|
191 |
+
'''
|
192 |
+
# if hasattr(self.opt, 'resume'):
|
193 |
+
# assert self.opt.resume is False
|
194 |
+
assert os.path.isfile(self.opt.load_weights_path), self.opt.load_weights_path
|
195 |
+
|
196 |
+
saved_weights = torch.load(self.opt.load_weights_path)['model_state_dict']
|
197 |
+
utils.safe_load_weights(self.model, saved_weights)
|
198 |
+
content_list = []
|
199 |
+
content_list += [f'Loaded pretrained weights from {self.opt.load_weights_path}']
|
200 |
+
utils.print_notification(content_list)
|
third_party/COTR/COTR/trainers/tensorboard_helper.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
|
3 |
+
import tensorboardX
|
4 |
+
|
5 |
+
|
6 |
+
class TensorboardDatapack():
|
7 |
+
'''data dictionary for pushing to tb
|
8 |
+
'''
|
9 |
+
|
10 |
+
def __init__(self):
|
11 |
+
self.SCALAR_NAME = 'scalar'
|
12 |
+
self.HISTOGRAM_NAME = 'histogram'
|
13 |
+
self.IMAGE_NAME = 'image'
|
14 |
+
self.TEXT_NAME = 'text'
|
15 |
+
self.datapack = {}
|
16 |
+
self.datapack[self.SCALAR_NAME] = {}
|
17 |
+
self.datapack[self.HISTOGRAM_NAME] = {}
|
18 |
+
self.datapack[self.IMAGE_NAME] = {}
|
19 |
+
self.datapack[self.TEXT_NAME] = {}
|
20 |
+
|
21 |
+
def set_training(self, training):
|
22 |
+
self.training = training
|
23 |
+
|
24 |
+
def set_iteration(self, iteration):
|
25 |
+
self.iteration = iteration
|
26 |
+
|
27 |
+
def add_scalar(self, scalar_dict):
|
28 |
+
self.datapack[self.SCALAR_NAME].update(scalar_dict)
|
29 |
+
|
30 |
+
def add_histogram(self, histogram_dict):
|
31 |
+
self.datapack[self.HISTOGRAM_NAME].update(histogram_dict)
|
32 |
+
|
33 |
+
def add_image(self, image_dict):
|
34 |
+
self.datapack[self.IMAGE_NAME].update(image_dict)
|
35 |
+
|
36 |
+
def add_text(self, text_dict):
|
37 |
+
self.datapack[self.TEXT_NAME].update(text_dict)
|
38 |
+
|
39 |
+
|
40 |
+
class TensorboardHelperBase(abc.ABC):
|
41 |
+
'''abstract base class for tb helpers
|
42 |
+
'''
|
43 |
+
|
44 |
+
def __init__(self, tb_writer):
|
45 |
+
self.tb_writer = tb_writer
|
46 |
+
|
47 |
+
@abc.abstractmethod
|
48 |
+
def add_data(self, tb_datapack):
|
49 |
+
pass
|
50 |
+
|
51 |
+
|
52 |
+
class TensorboardScalarHelper(TensorboardHelperBase):
|
53 |
+
def add_data(self, tb_datapack):
|
54 |
+
scalar_dict = tb_datapack.datapack[tb_datapack.SCALAR_NAME]
|
55 |
+
for key, val in scalar_dict.items():
|
56 |
+
self.tb_writer.add_scalar(
|
57 |
+
key, val, global_step=tb_datapack.iteration)
|
58 |
+
|
59 |
+
|
60 |
+
class TensorboardHistogramHelper(TensorboardHelperBase):
|
61 |
+
def add_data(self, tb_datapack):
|
62 |
+
histogram_dict = tb_datapack.datapack[tb_datapack.HISTOGRAM_NAME]
|
63 |
+
for key, val in histogram_dict.items():
|
64 |
+
self.tb_writer.add_histogram(
|
65 |
+
key, val, global_step=tb_datapack.iteration)
|
66 |
+
|
67 |
+
|
68 |
+
class TensorboardImageHelper(TensorboardHelperBase):
|
69 |
+
def add_data(self, tb_datapack):
|
70 |
+
image_dict = tb_datapack.datapack[tb_datapack.IMAGE_NAME]
|
71 |
+
for key, val in image_dict.items():
|
72 |
+
self.tb_writer.add_image(
|
73 |
+
key, val, global_step=tb_datapack.iteration)
|
74 |
+
|
75 |
+
|
76 |
+
class TensorboardTextHelper(TensorboardHelperBase):
|
77 |
+
def add_data(self, tb_datapack):
|
78 |
+
text_dict = tb_datapack.datapack[tb_datapack.TEXT_NAME]
|
79 |
+
for key, val in text_dict.items():
|
80 |
+
self.tb_writer.add_text(
|
81 |
+
key, val, global_step=tb_datapack.iteration)
|
82 |
+
|
83 |
+
|
84 |
+
class TensorboardPusher():
|
85 |
+
def __init__(self, opt):
|
86 |
+
self.tb_writer = tensorboardX.SummaryWriter(opt.tb_out)
|
87 |
+
scalar_helper = TensorboardScalarHelper(self.tb_writer)
|
88 |
+
histogram_helper = TensorboardHistogramHelper(self.tb_writer)
|
89 |
+
image_helper = TensorboardImageHelper(self.tb_writer)
|
90 |
+
text_helper = TensorboardTextHelper(self.tb_writer)
|
91 |
+
self.helper_list = [scalar_helper,
|
92 |
+
histogram_helper, image_helper, text_helper]
|
93 |
+
|
94 |
+
def push_to_tensorboard(self, tb_datapack):
|
95 |
+
for helper in self.helper_list:
|
96 |
+
helper.add_data(tb_datapack)
|
97 |
+
self.tb_writer.flush()
|
third_party/COTR/COTR/transformations/transform_basics.py
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from COTR.transformations import transformations
|
4 |
+
from COTR.utils import constants
|
5 |
+
|
6 |
+
|
7 |
+
class Rotation():
|
8 |
+
def __init__(self, quat):
|
9 |
+
"""
|
10 |
+
quaternion format (w, x, y, z)
|
11 |
+
"""
|
12 |
+
assert quat.dtype == np.float32
|
13 |
+
self.quaternion = quat
|
14 |
+
|
15 |
+
def __str__(self):
|
16 |
+
string = '{0}'.format(self.quaternion)
|
17 |
+
return string
|
18 |
+
|
19 |
+
@classmethod
|
20 |
+
def from_matrix(cls, mat):
|
21 |
+
assert isinstance(mat, np.ndarray)
|
22 |
+
if mat.shape == (3, 3):
|
23 |
+
id_mat = np.eye(4)
|
24 |
+
id_mat[0:3, 0:3] = mat
|
25 |
+
mat = id_mat
|
26 |
+
assert mat.shape == (4, 4)
|
27 |
+
quat = transformations.quaternion_from_matrix(mat).astype(constants.DEFAULT_PRECISION)
|
28 |
+
return cls(quat)
|
29 |
+
|
30 |
+
@property
|
31 |
+
def rotation_matrix(self):
|
32 |
+
return transformations.quaternion_matrix(self.quaternion).astype(constants.DEFAULT_PRECISION)
|
33 |
+
|
34 |
+
@rotation_matrix.setter
|
35 |
+
def rotation_matrix(self, mat):
|
36 |
+
assert isinstance(mat, np.ndarray)
|
37 |
+
assert mat.shape == (4, 4)
|
38 |
+
quat = transformations.quaternion_from_matrix(mat)
|
39 |
+
self.quaternion = quat
|
40 |
+
|
41 |
+
@property
|
42 |
+
def quaternion(self):
|
43 |
+
assert isinstance(self._quaternion, np.ndarray)
|
44 |
+
assert self._quaternion.shape == (4,)
|
45 |
+
assert np.isclose(np.linalg.norm(self._quaternion), 1.0), 'self._quaternion is not normalized or valid'
|
46 |
+
return self._quaternion
|
47 |
+
|
48 |
+
@quaternion.setter
|
49 |
+
def quaternion(self, quat):
|
50 |
+
assert isinstance(quat, np.ndarray)
|
51 |
+
assert quat.shape == (4,)
|
52 |
+
if not np.isclose(np.linalg.norm(quat), 1.0):
|
53 |
+
print(f'WARNING: normalizing the input quatternion to unit quaternion: {np.linalg.norm(quat)}')
|
54 |
+
quat = quat / np.linalg.norm(quat)
|
55 |
+
assert np.isclose(np.linalg.norm(quat), 1.0), f'input quaternion is not normalized or valid: {quat}'
|
56 |
+
self._quaternion = quat
|
57 |
+
|
58 |
+
|
59 |
+
class UnstableRotation():
|
60 |
+
def __init__(self, mat):
|
61 |
+
assert isinstance(mat, np.ndarray)
|
62 |
+
if mat.shape == (3, 3):
|
63 |
+
id_mat = np.eye(4)
|
64 |
+
id_mat[0:3, 0:3] = mat
|
65 |
+
mat = id_mat
|
66 |
+
assert mat.shape == (4, 4)
|
67 |
+
mat[:3, 3] = 0
|
68 |
+
self._rotation_matrix = mat
|
69 |
+
|
70 |
+
def __str__(self):
|
71 |
+
string = f'rotation_matrix: {self.rotation_matrix}'
|
72 |
+
return string
|
73 |
+
|
74 |
+
@property
|
75 |
+
def rotation_matrix(self):
|
76 |
+
return self._rotation_matrix
|
77 |
+
|
78 |
+
|
79 |
+
class Translation():
|
80 |
+
def __init__(self, vec):
|
81 |
+
assert vec.dtype == np.float32
|
82 |
+
self.translation_vector = vec
|
83 |
+
|
84 |
+
def __str__(self):
|
85 |
+
string = '{0}'.format(self.translation_vector)
|
86 |
+
return string
|
87 |
+
|
88 |
+
@classmethod
|
89 |
+
def from_matrix(cls, mat):
|
90 |
+
assert isinstance(mat, np.ndarray)
|
91 |
+
assert mat.shape == (4, 4)
|
92 |
+
vec = transformations.translation_from_matrix(mat)
|
93 |
+
return cls(vec)
|
94 |
+
|
95 |
+
@property
|
96 |
+
def translation_matrix(self):
|
97 |
+
return transformations.translation_matrix(self.translation_vector).astype(constants.DEFAULT_PRECISION)
|
98 |
+
|
99 |
+
@translation_matrix.setter
|
100 |
+
def translation_matrix(self, mat):
|
101 |
+
assert isinstance(mat, np.ndarray)
|
102 |
+
assert mat.shape == (4, 4)
|
103 |
+
vec = transformations.translation_from_matrix(mat)
|
104 |
+
self.translation_vector = vec
|
105 |
+
|
106 |
+
@property
|
107 |
+
def translation_vector(self):
|
108 |
+
return self._translation_vector
|
109 |
+
|
110 |
+
@translation_vector.setter
|
111 |
+
def translation_vector(self, vec):
|
112 |
+
assert isinstance(vec, np.ndarray)
|
113 |
+
assert vec.shape == (3,)
|
114 |
+
self._translation_vector = vec
|
third_party/COTR/COTR/transformations/transformations.py
ADDED
@@ -0,0 +1,1951 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# transformations.py
|
3 |
+
|
4 |
+
# Copyright (c) 2006-2019, Christoph Gohlke
|
5 |
+
# Copyright (c) 2006-2019, The Regents of the University of California
|
6 |
+
# Produced at the Laboratory for Fluorescence Dynamics
|
7 |
+
# All rights reserved.
|
8 |
+
#
|
9 |
+
# Redistribution and use in source and binary forms, with or without
|
10 |
+
# modification, are permitted provided that the following conditions are met:
|
11 |
+
#
|
12 |
+
# * Redistributions of source code must retain the above copyright notice,
|
13 |
+
# this list of conditions and the following disclaimer.
|
14 |
+
#
|
15 |
+
# * Redistributions in binary form must reproduce the above copyright notice,
|
16 |
+
# this list of conditions and the following disclaimer in the documentation
|
17 |
+
# and/or other materials provided with the distribution.
|
18 |
+
#
|
19 |
+
# * Neither the name of the copyright holder nor the names of its
|
20 |
+
# contributors may be used to endorse or promote products derived from
|
21 |
+
# this software without specific prior written permission.
|
22 |
+
#
|
23 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
24 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
25 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
26 |
+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
27 |
+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
28 |
+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
29 |
+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
30 |
+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
31 |
+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
32 |
+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
33 |
+
# POSSIBILITY OF SUCH DAMAGE.
|
34 |
+
|
35 |
+
"""Homogeneous Transformation Matrices and Quaternions.
|
36 |
+
|
37 |
+
Transformations is a Python library for calculating 4x4 matrices for
|
38 |
+
translating, rotating, reflecting, scaling, shearing, projecting,
|
39 |
+
orthogonalizing, and superimposing arrays of 3D homogeneous coordinates
|
40 |
+
as well as for converting between rotation matrices, Euler angles,
|
41 |
+
and quaternions. Also includes an Arcball control object and
|
42 |
+
functions to decompose transformation matrices.
|
43 |
+
|
44 |
+
:Author:
|
45 |
+
`Christoph Gohlke <https://www.lfd.uci.edu/~gohlke/>`_
|
46 |
+
|
47 |
+
:Organization:
|
48 |
+
Laboratory for Fluorescence Dynamics. University of California, Irvine
|
49 |
+
|
50 |
+
:License: 3-clause BSD
|
51 |
+
|
52 |
+
:Version: 2019.2.20
|
53 |
+
|
54 |
+
Requirements
|
55 |
+
------------
|
56 |
+
* `CPython 2.7 or 3.5+ <https://www.python.org>`_
|
57 |
+
* `Numpy 1.14 <https://www.numpy.org>`_
|
58 |
+
* A Python distutils compatible C compiler (build)
|
59 |
+
|
60 |
+
Revisions
|
61 |
+
---------
|
62 |
+
2019.1.1
|
63 |
+
Update copyright year.
|
64 |
+
|
65 |
+
Notes
|
66 |
+
-----
|
67 |
+
Transformations.py is no longer actively developed and has a few known issues
|
68 |
+
and numerical instabilities. The module is mostly superseded by other modules
|
69 |
+
for 3D transformations and quaternions:
|
70 |
+
|
71 |
+
* `Scipy.spatial.transform <https://github.com/scipy/scipy/tree/master/
|
72 |
+
scipy/spatial/transform>`_
|
73 |
+
* `Transforms3d <https://github.com/matthew-brett/transforms3d>`_
|
74 |
+
(includes most code of this module)
|
75 |
+
* `Numpy-quaternion <https://github.com/moble/quaternion>`_
|
76 |
+
* `Blender.mathutils <https://docs.blender.org/api/master/mathutils.html>`_
|
77 |
+
|
78 |
+
The API is not stable yet and is expected to change between revisions.
|
79 |
+
|
80 |
+
Python 2.7 and 3.4 are deprecated.
|
81 |
+
|
82 |
+
This Python code is not optimized for speed. Refer to the transformations.c
|
83 |
+
module for a faster implementation of some functions.
|
84 |
+
|
85 |
+
Documentation in HTML format can be generated with epydoc.
|
86 |
+
|
87 |
+
Matrices (M) can be inverted using numpy.linalg.inv(M), be concatenated using
|
88 |
+
numpy.dot(M0, M1), or transform homogeneous coordinate arrays (v) using
|
89 |
+
numpy.dot(M, v) for shape (4, \*) column vectors, respectively
|
90 |
+
numpy.dot(v, M.T) for shape (\*, 4) row vectors ("array of points").
|
91 |
+
|
92 |
+
This module follows the "column vectors on the right" and "row major storage"
|
93 |
+
(C contiguous) conventions. The translation components are in the right column
|
94 |
+
of the transformation matrix, i.e. M[:3, 3].
|
95 |
+
The transpose of the transformation matrices may have to be used to interface
|
96 |
+
with other graphics systems, e.g. OpenGL's glMultMatrixd(). See also [16].
|
97 |
+
|
98 |
+
Calculations are carried out with numpy.float64 precision.
|
99 |
+
|
100 |
+
Vector, point, quaternion, and matrix function arguments are expected to be
|
101 |
+
"array like", i.e. tuple, list, or numpy arrays.
|
102 |
+
|
103 |
+
Return types are numpy arrays unless specified otherwise.
|
104 |
+
|
105 |
+
Angles are in radians unless specified otherwise.
|
106 |
+
|
107 |
+
Quaternions w+ix+jy+kz are represented as [w, x, y, z].
|
108 |
+
|
109 |
+
A triple of Euler angles can be applied/interpreted in 24 ways, which can
|
110 |
+
be specified using a 4 character string or encoded 4-tuple:
|
111 |
+
|
112 |
+
*Axes 4-string*: e.g. 'sxyz' or 'ryxy'
|
113 |
+
|
114 |
+
- first character : rotations are applied to 's'tatic or 'r'otating frame
|
115 |
+
- remaining characters : successive rotation axis 'x', 'y', or 'z'
|
116 |
+
|
117 |
+
*Axes 4-tuple*: e.g. (0, 0, 0, 0) or (1, 1, 1, 1)
|
118 |
+
|
119 |
+
- inner axis: code of axis ('x':0, 'y':1, 'z':2) of rightmost matrix.
|
120 |
+
- parity : even (0) if inner axis 'x' is followed by 'y', 'y' is followed
|
121 |
+
by 'z', or 'z' is followed by 'x'. Otherwise odd (1).
|
122 |
+
- repetition : first and last axis are same (1) or different (0).
|
123 |
+
- frame : rotations are applied to static (0) or rotating (1) frame.
|
124 |
+
|
125 |
+
References
|
126 |
+
----------
|
127 |
+
(1) Matrices and transformations. Ronald Goldman.
|
128 |
+
In "Graphics Gems I", pp 472-475. Morgan Kaufmann, 1990.
|
129 |
+
(2) More matrices and transformations: shear and pseudo-perspective.
|
130 |
+
Ronald Goldman. In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
|
131 |
+
(3) Decomposing a matrix into simple transformations. Spencer Thomas.
|
132 |
+
In "Graphics Gems II", pp 320-323. Morgan Kaufmann, 1991.
|
133 |
+
(4) Recovering the data from the transformation matrix. Ronald Goldman.
|
134 |
+
In "Graphics Gems II", pp 324-331. Morgan Kaufmann, 1991.
|
135 |
+
(5) Euler angle conversion. Ken Shoemake.
|
136 |
+
In "Graphics Gems IV", pp 222-229. Morgan Kaufmann, 1994.
|
137 |
+
(6) Arcball rotation control. Ken Shoemake.
|
138 |
+
In "Graphics Gems IV", pp 175-192. Morgan Kaufmann, 1994.
|
139 |
+
(7) Representing attitude: Euler angles, unit quaternions, and rotation
|
140 |
+
vectors. James Diebel. 2006.
|
141 |
+
(8) A discussion of the solution for the best rotation to relate two sets
|
142 |
+
of vectors. W Kabsch. Acta Cryst. 1978. A34, 827-828.
|
143 |
+
(9) Closed-form solution of absolute orientation using unit quaternions.
|
144 |
+
BKP Horn. J Opt Soc Am A. 1987. 4(4):629-642.
|
145 |
+
(10) Quaternions. Ken Shoemake.
|
146 |
+
http://www.sfu.ca/~jwa3/cmpt461/files/quatut.pdf
|
147 |
+
(11) From quaternion to matrix and back. JMP van Waveren. 2005.
|
148 |
+
http://www.intel.com/cd/ids/developer/asmo-na/eng/293748.htm
|
149 |
+
(12) Uniform random rotations. Ken Shoemake.
|
150 |
+
In "Graphics Gems III", pp 124-132. Morgan Kaufmann, 1992.
|
151 |
+
(13) Quaternion in molecular modeling. CFF Karney.
|
152 |
+
J Mol Graph Mod, 25(5):595-604
|
153 |
+
(14) New method for extracting the quaternion from a rotation matrix.
|
154 |
+
Itzhack Y Bar-Itzhack, J Guid Contr Dynam. 2000. 23(6): 1085-1087.
|
155 |
+
(15) Multiple View Geometry in Computer Vision. Hartley and Zissermann.
|
156 |
+
Cambridge University Press; 2nd Ed. 2004. Chapter 4, Algorithm 4.7, p 130.
|
157 |
+
(16) Column Vectors vs. Row Vectors.
|
158 |
+
http://steve.hollasch.net/cgindex/math/matrix/column-vec.html
|
159 |
+
|
160 |
+
Examples
|
161 |
+
--------
|
162 |
+
>>> alpha, beta, gamma = 0.123, -1.234, 2.345
|
163 |
+
>>> origin, xaxis, yaxis, zaxis = [0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]
|
164 |
+
>>> I = identity_matrix()
|
165 |
+
>>> Rx = rotation_matrix(alpha, xaxis)
|
166 |
+
>>> Ry = rotation_matrix(beta, yaxis)
|
167 |
+
>>> Rz = rotation_matrix(gamma, zaxis)
|
168 |
+
>>> R = concatenate_matrices(Rx, Ry, Rz)
|
169 |
+
>>> euler = euler_from_matrix(R, 'rxyz')
|
170 |
+
>>> numpy.allclose([alpha, beta, gamma], euler)
|
171 |
+
True
|
172 |
+
>>> Re = euler_matrix(alpha, beta, gamma, 'rxyz')
|
173 |
+
>>> is_same_transform(R, Re)
|
174 |
+
True
|
175 |
+
>>> al, be, ga = euler_from_matrix(Re, 'rxyz')
|
176 |
+
>>> is_same_transform(Re, euler_matrix(al, be, ga, 'rxyz'))
|
177 |
+
True
|
178 |
+
>>> qx = quaternion_about_axis(alpha, xaxis)
|
179 |
+
>>> qy = quaternion_about_axis(beta, yaxis)
|
180 |
+
>>> qz = quaternion_about_axis(gamma, zaxis)
|
181 |
+
>>> q = quaternion_multiply(qx, qy)
|
182 |
+
>>> q = quaternion_multiply(q, qz)
|
183 |
+
>>> Rq = quaternion_matrix(q)
|
184 |
+
>>> is_same_transform(R, Rq)
|
185 |
+
True
|
186 |
+
>>> S = scale_matrix(1.23, origin)
|
187 |
+
>>> T = translation_matrix([1, 2, 3])
|
188 |
+
>>> Z = shear_matrix(beta, xaxis, origin, zaxis)
|
189 |
+
>>> R = random_rotation_matrix(numpy.random.rand(3))
|
190 |
+
>>> M = concatenate_matrices(T, R, Z, S)
|
191 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(M)
|
192 |
+
>>> numpy.allclose(scale, 1.23)
|
193 |
+
True
|
194 |
+
>>> numpy.allclose(trans, [1, 2, 3])
|
195 |
+
True
|
196 |
+
>>> numpy.allclose(shear, [0, math.tan(beta), 0])
|
197 |
+
True
|
198 |
+
>>> is_same_transform(R, euler_matrix(axes='sxyz', *angles))
|
199 |
+
True
|
200 |
+
>>> M1 = compose_matrix(scale, shear, angles, trans, persp)
|
201 |
+
>>> is_same_transform(M, M1)
|
202 |
+
True
|
203 |
+
>>> v0, v1 = random_vector(3), random_vector(3)
|
204 |
+
>>> M = rotation_matrix(angle_between_vectors(v0, v1), vector_product(v0, v1))
|
205 |
+
>>> v2 = numpy.dot(v0, M[:3,:3].T)
|
206 |
+
>>> numpy.allclose(unit_vector(v1), unit_vector(v2))
|
207 |
+
True
|
208 |
+
|
209 |
+
"""
|
210 |
+
|
211 |
+
from __future__ import division, print_function
|
212 |
+
|
213 |
+
__version__ = '2019.2.20'
|
214 |
+
__docformat__ = 'restructuredtext en'
|
215 |
+
|
216 |
+
import math
|
217 |
+
|
218 |
+
import numpy
|
219 |
+
|
220 |
+
|
221 |
+
def identity_matrix():
|
222 |
+
"""Return 4x4 identity/unit matrix.
|
223 |
+
|
224 |
+
>>> I = identity_matrix()
|
225 |
+
>>> numpy.allclose(I, numpy.dot(I, I))
|
226 |
+
True
|
227 |
+
>>> numpy.sum(I), numpy.trace(I)
|
228 |
+
(4.0, 4.0)
|
229 |
+
>>> numpy.allclose(I, numpy.identity(4))
|
230 |
+
True
|
231 |
+
|
232 |
+
"""
|
233 |
+
return numpy.identity(4)
|
234 |
+
|
235 |
+
|
236 |
+
def translation_matrix(direction):
|
237 |
+
"""Return matrix to translate by direction vector.
|
238 |
+
|
239 |
+
>>> v = numpy.random.random(3) - 0.5
|
240 |
+
>>> numpy.allclose(v, translation_matrix(v)[:3, 3])
|
241 |
+
True
|
242 |
+
|
243 |
+
"""
|
244 |
+
M = numpy.identity(4)
|
245 |
+
M[:3, 3] = direction[:3]
|
246 |
+
return M
|
247 |
+
|
248 |
+
|
249 |
+
def translation_from_matrix(matrix):
|
250 |
+
"""Return translation vector from translation matrix.
|
251 |
+
|
252 |
+
>>> v0 = numpy.random.random(3) - 0.5
|
253 |
+
>>> v1 = translation_from_matrix(translation_matrix(v0))
|
254 |
+
>>> numpy.allclose(v0, v1)
|
255 |
+
True
|
256 |
+
|
257 |
+
"""
|
258 |
+
return numpy.array(matrix, copy=False)[:3, 3].copy()
|
259 |
+
|
260 |
+
|
261 |
+
def reflection_matrix(point, normal):
|
262 |
+
"""Return matrix to mirror at plane defined by point and normal vector.
|
263 |
+
|
264 |
+
>>> v0 = numpy.random.random(4) - 0.5
|
265 |
+
>>> v0[3] = 1.
|
266 |
+
>>> v1 = numpy.random.random(3) - 0.5
|
267 |
+
>>> R = reflection_matrix(v0, v1)
|
268 |
+
>>> numpy.allclose(2, numpy.trace(R))
|
269 |
+
True
|
270 |
+
>>> numpy.allclose(v0, numpy.dot(R, v0))
|
271 |
+
True
|
272 |
+
>>> v2 = v0.copy()
|
273 |
+
>>> v2[:3] += v1
|
274 |
+
>>> v3 = v0.copy()
|
275 |
+
>>> v2[:3] -= v1
|
276 |
+
>>> numpy.allclose(v2, numpy.dot(R, v3))
|
277 |
+
True
|
278 |
+
|
279 |
+
"""
|
280 |
+
normal = unit_vector(normal[:3])
|
281 |
+
M = numpy.identity(4)
|
282 |
+
M[:3, :3] -= 2.0 * numpy.outer(normal, normal)
|
283 |
+
M[:3, 3] = (2.0 * numpy.dot(point[:3], normal)) * normal
|
284 |
+
return M
|
285 |
+
|
286 |
+
|
287 |
+
def reflection_from_matrix(matrix):
|
288 |
+
"""Return mirror plane point and normal vector from reflection matrix.
|
289 |
+
|
290 |
+
>>> v0 = numpy.random.random(3) - 0.5
|
291 |
+
>>> v1 = numpy.random.random(3) - 0.5
|
292 |
+
>>> M0 = reflection_matrix(v0, v1)
|
293 |
+
>>> point, normal = reflection_from_matrix(M0)
|
294 |
+
>>> M1 = reflection_matrix(point, normal)
|
295 |
+
>>> is_same_transform(M0, M1)
|
296 |
+
True
|
297 |
+
|
298 |
+
"""
|
299 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
300 |
+
# normal: unit eigenvector corresponding to eigenvalue -1
|
301 |
+
w, V = numpy.linalg.eig(M[:3, :3])
|
302 |
+
i = numpy.where(abs(numpy.real(w) + 1.0) < 1e-8)[0]
|
303 |
+
if not len(i):
|
304 |
+
raise ValueError('no unit eigenvector corresponding to eigenvalue -1')
|
305 |
+
normal = numpy.real(V[:, i[0]]).squeeze()
|
306 |
+
# point: any unit eigenvector corresponding to eigenvalue 1
|
307 |
+
w, V = numpy.linalg.eig(M)
|
308 |
+
i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
|
309 |
+
if not len(i):
|
310 |
+
raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
|
311 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
312 |
+
point /= point[3]
|
313 |
+
return point, normal
|
314 |
+
|
315 |
+
|
316 |
+
def rotation_matrix(angle, direction, point=None):
|
317 |
+
"""Return matrix to rotate about axis defined by point and direction.
|
318 |
+
|
319 |
+
>>> R = rotation_matrix(math.pi/2, [0, 0, 1], [1, 0, 0])
|
320 |
+
>>> numpy.allclose(numpy.dot(R, [0, 0, 0, 1]), [1, -1, 0, 1])
|
321 |
+
True
|
322 |
+
>>> angle = (random.random() - 0.5) * (2*math.pi)
|
323 |
+
>>> direc = numpy.random.random(3) - 0.5
|
324 |
+
>>> point = numpy.random.random(3) - 0.5
|
325 |
+
>>> R0 = rotation_matrix(angle, direc, point)
|
326 |
+
>>> R1 = rotation_matrix(angle-2*math.pi, direc, point)
|
327 |
+
>>> is_same_transform(R0, R1)
|
328 |
+
True
|
329 |
+
>>> R0 = rotation_matrix(angle, direc, point)
|
330 |
+
>>> R1 = rotation_matrix(-angle, -direc, point)
|
331 |
+
>>> is_same_transform(R0, R1)
|
332 |
+
True
|
333 |
+
>>> I = numpy.identity(4, numpy.float64)
|
334 |
+
>>> numpy.allclose(I, rotation_matrix(math.pi*2, direc))
|
335 |
+
True
|
336 |
+
>>> numpy.allclose(2, numpy.trace(rotation_matrix(math.pi/2,
|
337 |
+
... direc, point)))
|
338 |
+
True
|
339 |
+
|
340 |
+
"""
|
341 |
+
sina = math.sin(angle)
|
342 |
+
cosa = math.cos(angle)
|
343 |
+
direction = unit_vector(direction[:3])
|
344 |
+
# rotation matrix around unit vector
|
345 |
+
R = numpy.diag([cosa, cosa, cosa])
|
346 |
+
R += numpy.outer(direction, direction) * (1.0 - cosa)
|
347 |
+
direction *= sina
|
348 |
+
R += numpy.array([[0.0, -direction[2], direction[1]],
|
349 |
+
[direction[2], 0.0, -direction[0]],
|
350 |
+
[-direction[1], direction[0], 0.0]])
|
351 |
+
M = numpy.identity(4)
|
352 |
+
M[:3, :3] = R
|
353 |
+
if point is not None:
|
354 |
+
# rotation not around origin
|
355 |
+
point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
|
356 |
+
M[:3, 3] = point - numpy.dot(R, point)
|
357 |
+
return M
|
358 |
+
|
359 |
+
|
360 |
+
def rotation_from_matrix(matrix):
|
361 |
+
"""Return rotation angle and axis from rotation matrix.
|
362 |
+
|
363 |
+
>>> angle = (random.random() - 0.5) * (2*math.pi)
|
364 |
+
>>> direc = numpy.random.random(3) - 0.5
|
365 |
+
>>> point = numpy.random.random(3) - 0.5
|
366 |
+
>>> R0 = rotation_matrix(angle, direc, point)
|
367 |
+
>>> angle, direc, point = rotation_from_matrix(R0)
|
368 |
+
>>> R1 = rotation_matrix(angle, direc, point)
|
369 |
+
>>> is_same_transform(R0, R1)
|
370 |
+
True
|
371 |
+
|
372 |
+
"""
|
373 |
+
R = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
374 |
+
R33 = R[:3, :3]
|
375 |
+
# direction: unit eigenvector of R33 corresponding to eigenvalue of 1
|
376 |
+
w, W = numpy.linalg.eig(R33.T)
|
377 |
+
i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
|
378 |
+
if not len(i):
|
379 |
+
raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
|
380 |
+
direction = numpy.real(W[:, i[-1]]).squeeze()
|
381 |
+
# point: unit eigenvector of R33 corresponding to eigenvalue of 1
|
382 |
+
w, Q = numpy.linalg.eig(R)
|
383 |
+
i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
|
384 |
+
if not len(i):
|
385 |
+
raise ValueError('no unit eigenvector corresponding to eigenvalue 1')
|
386 |
+
point = numpy.real(Q[:, i[-1]]).squeeze()
|
387 |
+
point /= point[3]
|
388 |
+
# rotation angle depending on direction
|
389 |
+
cosa = (numpy.trace(R33) - 1.0) / 2.0
|
390 |
+
if abs(direction[2]) > 1e-8:
|
391 |
+
sina = (R[1, 0] + (cosa-1.0)*direction[0]*direction[1]) / direction[2]
|
392 |
+
elif abs(direction[1]) > 1e-8:
|
393 |
+
sina = (R[0, 2] + (cosa-1.0)*direction[0]*direction[2]) / direction[1]
|
394 |
+
else:
|
395 |
+
sina = (R[2, 1] + (cosa-1.0)*direction[1]*direction[2]) / direction[0]
|
396 |
+
angle = math.atan2(sina, cosa)
|
397 |
+
return angle, direction, point
|
398 |
+
|
399 |
+
|
400 |
+
def scale_matrix(factor, origin=None, direction=None):
|
401 |
+
"""Return matrix to scale by factor around origin in direction.
|
402 |
+
|
403 |
+
Use factor -1 for point symmetry.
|
404 |
+
|
405 |
+
>>> v = (numpy.random.rand(4, 5) - 0.5) * 20
|
406 |
+
>>> v[3] = 1
|
407 |
+
>>> S = scale_matrix(-1.234)
|
408 |
+
>>> numpy.allclose(numpy.dot(S, v)[:3], -1.234*v[:3])
|
409 |
+
True
|
410 |
+
>>> factor = random.random() * 10 - 5
|
411 |
+
>>> origin = numpy.random.random(3) - 0.5
|
412 |
+
>>> direct = numpy.random.random(3) - 0.5
|
413 |
+
>>> S = scale_matrix(factor, origin)
|
414 |
+
>>> S = scale_matrix(factor, origin, direct)
|
415 |
+
|
416 |
+
"""
|
417 |
+
if direction is None:
|
418 |
+
# uniform scaling
|
419 |
+
M = numpy.diag([factor, factor, factor, 1.0])
|
420 |
+
if origin is not None:
|
421 |
+
M[:3, 3] = origin[:3]
|
422 |
+
M[:3, 3] *= 1.0 - factor
|
423 |
+
else:
|
424 |
+
# nonuniform scaling
|
425 |
+
direction = unit_vector(direction[:3])
|
426 |
+
factor = 1.0 - factor
|
427 |
+
M = numpy.identity(4)
|
428 |
+
M[:3, :3] -= factor * numpy.outer(direction, direction)
|
429 |
+
if origin is not None:
|
430 |
+
M[:3, 3] = (factor * numpy.dot(origin[:3], direction)) * direction
|
431 |
+
return M
|
432 |
+
|
433 |
+
|
434 |
+
def scale_from_matrix(matrix):
|
435 |
+
"""Return scaling factor, origin and direction from scaling matrix.
|
436 |
+
|
437 |
+
>>> factor = random.random() * 10 - 5
|
438 |
+
>>> origin = numpy.random.random(3) - 0.5
|
439 |
+
>>> direct = numpy.random.random(3) - 0.5
|
440 |
+
>>> S0 = scale_matrix(factor, origin)
|
441 |
+
>>> factor, origin, direction = scale_from_matrix(S0)
|
442 |
+
>>> S1 = scale_matrix(factor, origin, direction)
|
443 |
+
>>> is_same_transform(S0, S1)
|
444 |
+
True
|
445 |
+
>>> S0 = scale_matrix(factor, origin, direct)
|
446 |
+
>>> factor, origin, direction = scale_from_matrix(S0)
|
447 |
+
>>> S1 = scale_matrix(factor, origin, direction)
|
448 |
+
>>> is_same_transform(S0, S1)
|
449 |
+
True
|
450 |
+
|
451 |
+
"""
|
452 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
453 |
+
M33 = M[:3, :3]
|
454 |
+
factor = numpy.trace(M33) - 2.0
|
455 |
+
try:
|
456 |
+
# direction: unit eigenvector corresponding to eigenvalue factor
|
457 |
+
w, V = numpy.linalg.eig(M33)
|
458 |
+
i = numpy.where(abs(numpy.real(w) - factor) < 1e-8)[0][0]
|
459 |
+
direction = numpy.real(V[:, i]).squeeze()
|
460 |
+
direction /= vector_norm(direction)
|
461 |
+
except IndexError:
|
462 |
+
# uniform scaling
|
463 |
+
factor = (factor + 2.0) / 3.0
|
464 |
+
direction = None
|
465 |
+
# origin: any eigenvector corresponding to eigenvalue 1
|
466 |
+
w, V = numpy.linalg.eig(M)
|
467 |
+
i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
|
468 |
+
if not len(i):
|
469 |
+
raise ValueError('no eigenvector corresponding to eigenvalue 1')
|
470 |
+
origin = numpy.real(V[:, i[-1]]).squeeze()
|
471 |
+
origin /= origin[3]
|
472 |
+
return factor, origin, direction
|
473 |
+
|
474 |
+
|
475 |
+
def projection_matrix(point, normal, direction=None,
|
476 |
+
perspective=None, pseudo=False):
|
477 |
+
"""Return matrix to project onto plane defined by point and normal.
|
478 |
+
|
479 |
+
Using either perspective point, projection direction, or none of both.
|
480 |
+
|
481 |
+
If pseudo is True, perspective projections will preserve relative depth
|
482 |
+
such that Perspective = dot(Orthogonal, PseudoPerspective).
|
483 |
+
|
484 |
+
>>> P = projection_matrix([0, 0, 0], [1, 0, 0])
|
485 |
+
>>> numpy.allclose(P[1:, 1:], numpy.identity(4)[1:, 1:])
|
486 |
+
True
|
487 |
+
>>> point = numpy.random.random(3) - 0.5
|
488 |
+
>>> normal = numpy.random.random(3) - 0.5
|
489 |
+
>>> direct = numpy.random.random(3) - 0.5
|
490 |
+
>>> persp = numpy.random.random(3) - 0.5
|
491 |
+
>>> P0 = projection_matrix(point, normal)
|
492 |
+
>>> P1 = projection_matrix(point, normal, direction=direct)
|
493 |
+
>>> P2 = projection_matrix(point, normal, perspective=persp)
|
494 |
+
>>> P3 = projection_matrix(point, normal, perspective=persp, pseudo=True)
|
495 |
+
>>> is_same_transform(P2, numpy.dot(P0, P3))
|
496 |
+
True
|
497 |
+
>>> P = projection_matrix([3, 0, 0], [1, 1, 0], [1, 0, 0])
|
498 |
+
>>> v0 = (numpy.random.rand(4, 5) - 0.5) * 20
|
499 |
+
>>> v0[3] = 1
|
500 |
+
>>> v1 = numpy.dot(P, v0)
|
501 |
+
>>> numpy.allclose(v1[1], v0[1])
|
502 |
+
True
|
503 |
+
>>> numpy.allclose(v1[0], 3-v1[1])
|
504 |
+
True
|
505 |
+
|
506 |
+
"""
|
507 |
+
M = numpy.identity(4)
|
508 |
+
point = numpy.array(point[:3], dtype=numpy.float64, copy=False)
|
509 |
+
normal = unit_vector(normal[:3])
|
510 |
+
if perspective is not None:
|
511 |
+
# perspective projection
|
512 |
+
perspective = numpy.array(perspective[:3], dtype=numpy.float64,
|
513 |
+
copy=False)
|
514 |
+
M[0, 0] = M[1, 1] = M[2, 2] = numpy.dot(perspective-point, normal)
|
515 |
+
M[:3, :3] -= numpy.outer(perspective, normal)
|
516 |
+
if pseudo:
|
517 |
+
# preserve relative depth
|
518 |
+
M[:3, :3] -= numpy.outer(normal, normal)
|
519 |
+
M[:3, 3] = numpy.dot(point, normal) * (perspective+normal)
|
520 |
+
else:
|
521 |
+
M[:3, 3] = numpy.dot(point, normal) * perspective
|
522 |
+
M[3, :3] = -normal
|
523 |
+
M[3, 3] = numpy.dot(perspective, normal)
|
524 |
+
elif direction is not None:
|
525 |
+
# parallel projection
|
526 |
+
direction = numpy.array(direction[:3], dtype=numpy.float64, copy=False)
|
527 |
+
scale = numpy.dot(direction, normal)
|
528 |
+
M[:3, :3] -= numpy.outer(direction, normal) / scale
|
529 |
+
M[:3, 3] = direction * (numpy.dot(point, normal) / scale)
|
530 |
+
else:
|
531 |
+
# orthogonal projection
|
532 |
+
M[:3, :3] -= numpy.outer(normal, normal)
|
533 |
+
M[:3, 3] = numpy.dot(point, normal) * normal
|
534 |
+
return M
|
535 |
+
|
536 |
+
|
537 |
+
def projection_from_matrix(matrix, pseudo=False):
|
538 |
+
"""Return projection plane and perspective point from projection matrix.
|
539 |
+
|
540 |
+
Return values are same as arguments for projection_matrix function:
|
541 |
+
point, normal, direction, perspective, and pseudo.
|
542 |
+
|
543 |
+
>>> point = numpy.random.random(3) - 0.5
|
544 |
+
>>> normal = numpy.random.random(3) - 0.5
|
545 |
+
>>> direct = numpy.random.random(3) - 0.5
|
546 |
+
>>> persp = numpy.random.random(3) - 0.5
|
547 |
+
>>> P0 = projection_matrix(point, normal)
|
548 |
+
>>> result = projection_from_matrix(P0)
|
549 |
+
>>> P1 = projection_matrix(*result)
|
550 |
+
>>> is_same_transform(P0, P1)
|
551 |
+
True
|
552 |
+
>>> P0 = projection_matrix(point, normal, direct)
|
553 |
+
>>> result = projection_from_matrix(P0)
|
554 |
+
>>> P1 = projection_matrix(*result)
|
555 |
+
>>> is_same_transform(P0, P1)
|
556 |
+
True
|
557 |
+
>>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=False)
|
558 |
+
>>> result = projection_from_matrix(P0, pseudo=False)
|
559 |
+
>>> P1 = projection_matrix(*result)
|
560 |
+
>>> is_same_transform(P0, P1)
|
561 |
+
True
|
562 |
+
>>> P0 = projection_matrix(point, normal, perspective=persp, pseudo=True)
|
563 |
+
>>> result = projection_from_matrix(P0, pseudo=True)
|
564 |
+
>>> P1 = projection_matrix(*result)
|
565 |
+
>>> is_same_transform(P0, P1)
|
566 |
+
True
|
567 |
+
|
568 |
+
"""
|
569 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
570 |
+
M33 = M[:3, :3]
|
571 |
+
w, V = numpy.linalg.eig(M)
|
572 |
+
i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
|
573 |
+
if not pseudo and len(i):
|
574 |
+
# point: any eigenvector corresponding to eigenvalue 1
|
575 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
576 |
+
point /= point[3]
|
577 |
+
# direction: unit eigenvector corresponding to eigenvalue 0
|
578 |
+
w, V = numpy.linalg.eig(M33)
|
579 |
+
i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
|
580 |
+
if not len(i):
|
581 |
+
raise ValueError('no eigenvector corresponding to eigenvalue 0')
|
582 |
+
direction = numpy.real(V[:, i[0]]).squeeze()
|
583 |
+
direction /= vector_norm(direction)
|
584 |
+
# normal: unit eigenvector of M33.T corresponding to eigenvalue 0
|
585 |
+
w, V = numpy.linalg.eig(M33.T)
|
586 |
+
i = numpy.where(abs(numpy.real(w)) < 1e-8)[0]
|
587 |
+
if len(i):
|
588 |
+
# parallel projection
|
589 |
+
normal = numpy.real(V[:, i[0]]).squeeze()
|
590 |
+
normal /= vector_norm(normal)
|
591 |
+
return point, normal, direction, None, False
|
592 |
+
else:
|
593 |
+
# orthogonal projection, where normal equals direction vector
|
594 |
+
return point, direction, None, None, False
|
595 |
+
else:
|
596 |
+
# perspective projection
|
597 |
+
i = numpy.where(abs(numpy.real(w)) > 1e-8)[0]
|
598 |
+
if not len(i):
|
599 |
+
raise ValueError(
|
600 |
+
'no eigenvector not corresponding to eigenvalue 0')
|
601 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
602 |
+
point /= point[3]
|
603 |
+
normal = - M[3, :3]
|
604 |
+
perspective = M[:3, 3] / numpy.dot(point[:3], normal)
|
605 |
+
if pseudo:
|
606 |
+
perspective -= normal
|
607 |
+
return point, normal, None, perspective, pseudo
|
608 |
+
|
609 |
+
|
610 |
+
def clip_matrix(left, right, bottom, top, near, far, perspective=False):
|
611 |
+
"""Return matrix to obtain normalized device coordinates from frustum.
|
612 |
+
|
613 |
+
The frustum bounds are axis-aligned along x (left, right),
|
614 |
+
y (bottom, top) and z (near, far).
|
615 |
+
|
616 |
+
Normalized device coordinates are in range [-1, 1] if coordinates are
|
617 |
+
inside the frustum.
|
618 |
+
|
619 |
+
If perspective is True the frustum is a truncated pyramid with the
|
620 |
+
perspective point at origin and direction along z axis, otherwise an
|
621 |
+
orthographic canonical view volume (a box).
|
622 |
+
|
623 |
+
Homogeneous coordinates transformed by the perspective clip matrix
|
624 |
+
need to be dehomogenized (divided by w coordinate).
|
625 |
+
|
626 |
+
>>> frustum = numpy.random.rand(6)
|
627 |
+
>>> frustum[1] += frustum[0]
|
628 |
+
>>> frustum[3] += frustum[2]
|
629 |
+
>>> frustum[5] += frustum[4]
|
630 |
+
>>> M = clip_matrix(perspective=False, *frustum)
|
631 |
+
>>> numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
|
632 |
+
array([-1., -1., -1., 1.])
|
633 |
+
>>> numpy.dot(M, [frustum[1], frustum[3], frustum[5], 1])
|
634 |
+
array([ 1., 1., 1., 1.])
|
635 |
+
>>> M = clip_matrix(perspective=True, *frustum)
|
636 |
+
>>> v = numpy.dot(M, [frustum[0], frustum[2], frustum[4], 1])
|
637 |
+
>>> v / v[3]
|
638 |
+
array([-1., -1., -1., 1.])
|
639 |
+
>>> v = numpy.dot(M, [frustum[1], frustum[3], frustum[4], 1])
|
640 |
+
>>> v / v[3]
|
641 |
+
array([ 1., 1., -1., 1.])
|
642 |
+
|
643 |
+
"""
|
644 |
+
if left >= right or bottom >= top or near >= far:
|
645 |
+
raise ValueError('invalid frustum')
|
646 |
+
if perspective:
|
647 |
+
if near <= _EPS:
|
648 |
+
raise ValueError('invalid frustum: near <= 0')
|
649 |
+
t = 2.0 * near
|
650 |
+
M = [[t/(left-right), 0.0, (right+left)/(right-left), 0.0],
|
651 |
+
[0.0, t/(bottom-top), (top+bottom)/(top-bottom), 0.0],
|
652 |
+
[0.0, 0.0, (far+near)/(near-far), t*far/(far-near)],
|
653 |
+
[0.0, 0.0, -1.0, 0.0]]
|
654 |
+
else:
|
655 |
+
M = [[2.0/(right-left), 0.0, 0.0, (right+left)/(left-right)],
|
656 |
+
[0.0, 2.0/(top-bottom), 0.0, (top+bottom)/(bottom-top)],
|
657 |
+
[0.0, 0.0, 2.0/(far-near), (far+near)/(near-far)],
|
658 |
+
[0.0, 0.0, 0.0, 1.0]]
|
659 |
+
return numpy.array(M)
|
660 |
+
|
661 |
+
|
662 |
+
def shear_matrix(angle, direction, point, normal):
|
663 |
+
"""Return matrix to shear by angle along direction vector on shear plane.
|
664 |
+
|
665 |
+
The shear plane is defined by a point and normal vector. The direction
|
666 |
+
vector must be orthogonal to the plane's normal vector.
|
667 |
+
|
668 |
+
A point P is transformed by the shear matrix into P" such that
|
669 |
+
the vector P-P" is parallel to the direction vector and its extent is
|
670 |
+
given by the angle of P-P'-P", where P' is the orthogonal projection
|
671 |
+
of P onto the shear plane.
|
672 |
+
|
673 |
+
>>> angle = (random.random() - 0.5) * 4*math.pi
|
674 |
+
>>> direct = numpy.random.random(3) - 0.5
|
675 |
+
>>> point = numpy.random.random(3) - 0.5
|
676 |
+
>>> normal = numpy.cross(direct, numpy.random.random(3))
|
677 |
+
>>> S = shear_matrix(angle, direct, point, normal)
|
678 |
+
>>> numpy.allclose(1, numpy.linalg.det(S))
|
679 |
+
True
|
680 |
+
|
681 |
+
"""
|
682 |
+
normal = unit_vector(normal[:3])
|
683 |
+
direction = unit_vector(direction[:3])
|
684 |
+
if abs(numpy.dot(normal, direction)) > 1e-6:
|
685 |
+
raise ValueError('direction and normal vectors are not orthogonal')
|
686 |
+
angle = math.tan(angle)
|
687 |
+
M = numpy.identity(4)
|
688 |
+
M[:3, :3] += angle * numpy.outer(direction, normal)
|
689 |
+
M[:3, 3] = -angle * numpy.dot(point[:3], normal) * direction
|
690 |
+
return M
|
691 |
+
|
692 |
+
|
693 |
+
def shear_from_matrix(matrix):
|
694 |
+
"""Return shear angle, direction and plane from shear matrix.
|
695 |
+
|
696 |
+
>>> angle = (random.random() - 0.5) * 4*math.pi
|
697 |
+
>>> direct = numpy.random.random(3) - 0.5
|
698 |
+
>>> point = numpy.random.random(3) - 0.5
|
699 |
+
>>> normal = numpy.cross(direct, numpy.random.random(3))
|
700 |
+
>>> S0 = shear_matrix(angle, direct, point, normal)
|
701 |
+
>>> angle, direct, point, normal = shear_from_matrix(S0)
|
702 |
+
>>> S1 = shear_matrix(angle, direct, point, normal)
|
703 |
+
>>> is_same_transform(S0, S1)
|
704 |
+
True
|
705 |
+
|
706 |
+
"""
|
707 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)
|
708 |
+
M33 = M[:3, :3]
|
709 |
+
# normal: cross independent eigenvectors corresponding to the eigenvalue 1
|
710 |
+
w, V = numpy.linalg.eig(M33)
|
711 |
+
i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-4)[0]
|
712 |
+
if len(i) < 2:
|
713 |
+
raise ValueError('no two linear independent eigenvectors found %s' % w)
|
714 |
+
V = numpy.real(V[:, i]).squeeze().T
|
715 |
+
lenorm = -1.0
|
716 |
+
for i0, i1 in ((0, 1), (0, 2), (1, 2)):
|
717 |
+
n = numpy.cross(V[i0], V[i1])
|
718 |
+
w = vector_norm(n)
|
719 |
+
if w > lenorm:
|
720 |
+
lenorm = w
|
721 |
+
normal = n
|
722 |
+
normal /= lenorm
|
723 |
+
# direction and angle
|
724 |
+
direction = numpy.dot(M33 - numpy.identity(3), normal)
|
725 |
+
angle = vector_norm(direction)
|
726 |
+
direction /= angle
|
727 |
+
angle = math.atan(angle)
|
728 |
+
# point: eigenvector corresponding to eigenvalue 1
|
729 |
+
w, V = numpy.linalg.eig(M)
|
730 |
+
i = numpy.where(abs(numpy.real(w) - 1.0) < 1e-8)[0]
|
731 |
+
if not len(i):
|
732 |
+
raise ValueError('no eigenvector corresponding to eigenvalue 1')
|
733 |
+
point = numpy.real(V[:, i[-1]]).squeeze()
|
734 |
+
point /= point[3]
|
735 |
+
return angle, direction, point, normal
|
736 |
+
|
737 |
+
|
738 |
+
def decompose_matrix(matrix):
|
739 |
+
"""Return sequence of transformations from transformation matrix.
|
740 |
+
|
741 |
+
matrix : array_like
|
742 |
+
Non-degenerative homogeneous transformation matrix
|
743 |
+
|
744 |
+
Return tuple of:
|
745 |
+
scale : vector of 3 scaling factors
|
746 |
+
shear : list of shear factors for x-y, x-z, y-z axes
|
747 |
+
angles : list of Euler angles about static x, y, z axes
|
748 |
+
translate : translation vector along x, y, z axes
|
749 |
+
perspective : perspective partition of matrix
|
750 |
+
|
751 |
+
Raise ValueError if matrix is of wrong type or degenerative.
|
752 |
+
|
753 |
+
>>> T0 = translation_matrix([1, 2, 3])
|
754 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(T0)
|
755 |
+
>>> T1 = translation_matrix(trans)
|
756 |
+
>>> numpy.allclose(T0, T1)
|
757 |
+
True
|
758 |
+
>>> S = scale_matrix(0.123)
|
759 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(S)
|
760 |
+
>>> scale[0]
|
761 |
+
0.123
|
762 |
+
>>> R0 = euler_matrix(1, 2, 3)
|
763 |
+
>>> scale, shear, angles, trans, persp = decompose_matrix(R0)
|
764 |
+
>>> R1 = euler_matrix(*angles)
|
765 |
+
>>> numpy.allclose(R0, R1)
|
766 |
+
True
|
767 |
+
|
768 |
+
"""
|
769 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=True).T
|
770 |
+
if abs(M[3, 3]) < _EPS:
|
771 |
+
raise ValueError('M[3, 3] is zero')
|
772 |
+
M /= M[3, 3]
|
773 |
+
P = M.copy()
|
774 |
+
P[:, 3] = 0.0, 0.0, 0.0, 1.0
|
775 |
+
if not numpy.linalg.det(P):
|
776 |
+
raise ValueError('matrix is singular')
|
777 |
+
|
778 |
+
scale = numpy.zeros((3, ))
|
779 |
+
shear = [0.0, 0.0, 0.0]
|
780 |
+
angles = [0.0, 0.0, 0.0]
|
781 |
+
|
782 |
+
if any(abs(M[:3, 3]) > _EPS):
|
783 |
+
perspective = numpy.dot(M[:, 3], numpy.linalg.inv(P.T))
|
784 |
+
M[:, 3] = 0.0, 0.0, 0.0, 1.0
|
785 |
+
else:
|
786 |
+
perspective = numpy.array([0.0, 0.0, 0.0, 1.0])
|
787 |
+
|
788 |
+
translate = M[3, :3].copy()
|
789 |
+
M[3, :3] = 0.0
|
790 |
+
|
791 |
+
row = M[:3, :3].copy()
|
792 |
+
scale[0] = vector_norm(row[0])
|
793 |
+
row[0] /= scale[0]
|
794 |
+
shear[0] = numpy.dot(row[0], row[1])
|
795 |
+
row[1] -= row[0] * shear[0]
|
796 |
+
scale[1] = vector_norm(row[1])
|
797 |
+
row[1] /= scale[1]
|
798 |
+
shear[0] /= scale[1]
|
799 |
+
shear[1] = numpy.dot(row[0], row[2])
|
800 |
+
row[2] -= row[0] * shear[1]
|
801 |
+
shear[2] = numpy.dot(row[1], row[2])
|
802 |
+
row[2] -= row[1] * shear[2]
|
803 |
+
scale[2] = vector_norm(row[2])
|
804 |
+
row[2] /= scale[2]
|
805 |
+
shear[1:] /= scale[2]
|
806 |
+
|
807 |
+
if numpy.dot(row[0], numpy.cross(row[1], row[2])) < 0:
|
808 |
+
numpy.negative(scale, scale)
|
809 |
+
numpy.negative(row, row)
|
810 |
+
|
811 |
+
angles[1] = math.asin(-row[0, 2])
|
812 |
+
if math.cos(angles[1]):
|
813 |
+
angles[0] = math.atan2(row[1, 2], row[2, 2])
|
814 |
+
angles[2] = math.atan2(row[0, 1], row[0, 0])
|
815 |
+
else:
|
816 |
+
# angles[0] = math.atan2(row[1, 0], row[1, 1])
|
817 |
+
angles[0] = math.atan2(-row[2, 1], row[1, 1])
|
818 |
+
angles[2] = 0.0
|
819 |
+
|
820 |
+
return scale, shear, angles, translate, perspective
|
821 |
+
|
822 |
+
|
823 |
+
def compose_matrix(scale=None, shear=None, angles=None, translate=None,
|
824 |
+
perspective=None):
|
825 |
+
"""Return transformation matrix from sequence of transformations.
|
826 |
+
|
827 |
+
This is the inverse of the decompose_matrix function.
|
828 |
+
|
829 |
+
Sequence of transformations:
|
830 |
+
scale : vector of 3 scaling factors
|
831 |
+
shear : list of shear factors for x-y, x-z, y-z axes
|
832 |
+
angles : list of Euler angles about static x, y, z axes
|
833 |
+
translate : translation vector along x, y, z axes
|
834 |
+
perspective : perspective partition of matrix
|
835 |
+
|
836 |
+
>>> scale = numpy.random.random(3) - 0.5
|
837 |
+
>>> shear = numpy.random.random(3) - 0.5
|
838 |
+
>>> angles = (numpy.random.random(3) - 0.5) * (2*math.pi)
|
839 |
+
>>> trans = numpy.random.random(3) - 0.5
|
840 |
+
>>> persp = numpy.random.random(4) - 0.5
|
841 |
+
>>> M0 = compose_matrix(scale, shear, angles, trans, persp)
|
842 |
+
>>> result = decompose_matrix(M0)
|
843 |
+
>>> M1 = compose_matrix(*result)
|
844 |
+
>>> is_same_transform(M0, M1)
|
845 |
+
True
|
846 |
+
|
847 |
+
"""
|
848 |
+
M = numpy.identity(4)
|
849 |
+
if perspective is not None:
|
850 |
+
P = numpy.identity(4)
|
851 |
+
P[3, :] = perspective[:4]
|
852 |
+
M = numpy.dot(M, P)
|
853 |
+
if translate is not None:
|
854 |
+
T = numpy.identity(4)
|
855 |
+
T[:3, 3] = translate[:3]
|
856 |
+
M = numpy.dot(M, T)
|
857 |
+
if angles is not None:
|
858 |
+
R = euler_matrix(angles[0], angles[1], angles[2], 'sxyz')
|
859 |
+
M = numpy.dot(M, R)
|
860 |
+
if shear is not None:
|
861 |
+
Z = numpy.identity(4)
|
862 |
+
Z[1, 2] = shear[2]
|
863 |
+
Z[0, 2] = shear[1]
|
864 |
+
Z[0, 1] = shear[0]
|
865 |
+
M = numpy.dot(M, Z)
|
866 |
+
if scale is not None:
|
867 |
+
S = numpy.identity(4)
|
868 |
+
S[0, 0] = scale[0]
|
869 |
+
S[1, 1] = scale[1]
|
870 |
+
S[2, 2] = scale[2]
|
871 |
+
M = numpy.dot(M, S)
|
872 |
+
M /= M[3, 3]
|
873 |
+
return M
|
874 |
+
|
875 |
+
|
876 |
+
def orthogonalization_matrix(lengths, angles):
|
877 |
+
"""Return orthogonalization matrix for crystallographic cell coordinates.
|
878 |
+
|
879 |
+
Angles are expected in degrees.
|
880 |
+
|
881 |
+
The de-orthogonalization matrix is the inverse.
|
882 |
+
|
883 |
+
>>> O = orthogonalization_matrix([10, 10, 10], [90, 90, 90])
|
884 |
+
>>> numpy.allclose(O[:3, :3], numpy.identity(3, float) * 10)
|
885 |
+
True
|
886 |
+
>>> O = orthogonalization_matrix([9.8, 12.0, 15.5], [87.2, 80.7, 69.7])
|
887 |
+
>>> numpy.allclose(numpy.sum(O), 43.063229)
|
888 |
+
True
|
889 |
+
|
890 |
+
"""
|
891 |
+
a, b, c = lengths
|
892 |
+
angles = numpy.radians(angles)
|
893 |
+
sina, sinb, _ = numpy.sin(angles)
|
894 |
+
cosa, cosb, cosg = numpy.cos(angles)
|
895 |
+
co = (cosa * cosb - cosg) / (sina * sinb)
|
896 |
+
return numpy.array([
|
897 |
+
[a*sinb*math.sqrt(1.0-co*co), 0.0, 0.0, 0.0],
|
898 |
+
[-a*sinb*co, b*sina, 0.0, 0.0],
|
899 |
+
[a*cosb, b*cosa, c, 0.0],
|
900 |
+
[0.0, 0.0, 0.0, 1.0]])
|
901 |
+
|
902 |
+
|
903 |
+
def affine_matrix_from_points(v0, v1, shear=True, scale=True, usesvd=True):
|
904 |
+
"""Return affine transform matrix to register two point sets.
|
905 |
+
|
906 |
+
v0 and v1 are shape (ndims, \*) arrays of at least ndims non-homogeneous
|
907 |
+
coordinates, where ndims is the dimensionality of the coordinate space.
|
908 |
+
|
909 |
+
If shear is False, a similarity transformation matrix is returned.
|
910 |
+
If also scale is False, a rigid/Euclidean transformation matrix
|
911 |
+
is returned.
|
912 |
+
|
913 |
+
By default the algorithm by Hartley and Zissermann [15] is used.
|
914 |
+
If usesvd is True, similarity and Euclidean transformation matrices
|
915 |
+
are calculated by minimizing the weighted sum of squared deviations
|
916 |
+
(RMSD) according to the algorithm by Kabsch [8].
|
917 |
+
Otherwise, and if ndims is 3, the quaternion based algorithm by Horn [9]
|
918 |
+
is used, which is slower when using this Python implementation.
|
919 |
+
|
920 |
+
The returned matrix performs rotation, translation and uniform scaling
|
921 |
+
(if specified).
|
922 |
+
|
923 |
+
>>> v0 = [[0, 1031, 1031, 0], [0, 0, 1600, 1600]]
|
924 |
+
>>> v1 = [[675, 826, 826, 677], [55, 52, 281, 277]]
|
925 |
+
>>> affine_matrix_from_points(v0, v1)
|
926 |
+
array([[ 0.14549, 0.00062, 675.50008],
|
927 |
+
[ 0.00048, 0.14094, 53.24971],
|
928 |
+
[ 0. , 0. , 1. ]])
|
929 |
+
>>> T = translation_matrix(numpy.random.random(3)-0.5)
|
930 |
+
>>> R = random_rotation_matrix(numpy.random.random(3))
|
931 |
+
>>> S = scale_matrix(random.random())
|
932 |
+
>>> M = concatenate_matrices(T, R, S)
|
933 |
+
>>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
|
934 |
+
>>> v0[3] = 1
|
935 |
+
>>> v1 = numpy.dot(M, v0)
|
936 |
+
>>> v0[:3] += numpy.random.normal(0, 1e-8, 300).reshape(3, -1)
|
937 |
+
>>> M = affine_matrix_from_points(v0[:3], v1[:3])
|
938 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
939 |
+
True
|
940 |
+
|
941 |
+
More examples in superimposition_matrix()
|
942 |
+
|
943 |
+
"""
|
944 |
+
v0 = numpy.array(v0, dtype=numpy.float64, copy=True)
|
945 |
+
v1 = numpy.array(v1, dtype=numpy.float64, copy=True)
|
946 |
+
|
947 |
+
ndims = v0.shape[0]
|
948 |
+
if ndims < 2 or v0.shape[1] < ndims or v0.shape != v1.shape:
|
949 |
+
raise ValueError('input arrays are of wrong shape or type')
|
950 |
+
|
951 |
+
# move centroids to origin
|
952 |
+
t0 = -numpy.mean(v0, axis=1)
|
953 |
+
M0 = numpy.identity(ndims+1)
|
954 |
+
M0[:ndims, ndims] = t0
|
955 |
+
v0 += t0.reshape(ndims, 1)
|
956 |
+
t1 = -numpy.mean(v1, axis=1)
|
957 |
+
M1 = numpy.identity(ndims+1)
|
958 |
+
M1[:ndims, ndims] = t1
|
959 |
+
v1 += t1.reshape(ndims, 1)
|
960 |
+
|
961 |
+
if shear:
|
962 |
+
# Affine transformation
|
963 |
+
A = numpy.concatenate((v0, v1), axis=0)
|
964 |
+
u, s, vh = numpy.linalg.svd(A.T)
|
965 |
+
vh = vh[:ndims].T
|
966 |
+
B = vh[:ndims]
|
967 |
+
C = vh[ndims:2*ndims]
|
968 |
+
t = numpy.dot(C, numpy.linalg.pinv(B))
|
969 |
+
t = numpy.concatenate((t, numpy.zeros((ndims, 1))), axis=1)
|
970 |
+
M = numpy.vstack((t, ((0.0,)*ndims) + (1.0,)))
|
971 |
+
elif usesvd or ndims != 3:
|
972 |
+
# Rigid transformation via SVD of covariance matrix
|
973 |
+
u, s, vh = numpy.linalg.svd(numpy.dot(v1, v0.T))
|
974 |
+
# rotation matrix from SVD orthonormal bases
|
975 |
+
R = numpy.dot(u, vh)
|
976 |
+
if numpy.linalg.det(R) < 0.0:
|
977 |
+
# R does not constitute right handed system
|
978 |
+
R -= numpy.outer(u[:, ndims-1], vh[ndims-1, :]*2.0)
|
979 |
+
s[-1] *= -1.0
|
980 |
+
# homogeneous transformation matrix
|
981 |
+
M = numpy.identity(ndims+1)
|
982 |
+
M[:ndims, :ndims] = R
|
983 |
+
else:
|
984 |
+
# Rigid transformation matrix via quaternion
|
985 |
+
# compute symmetric matrix N
|
986 |
+
xx, yy, zz = numpy.sum(v0 * v1, axis=1)
|
987 |
+
xy, yz, zx = numpy.sum(v0 * numpy.roll(v1, -1, axis=0), axis=1)
|
988 |
+
xz, yx, zy = numpy.sum(v0 * numpy.roll(v1, -2, axis=0), axis=1)
|
989 |
+
N = [[xx+yy+zz, 0.0, 0.0, 0.0],
|
990 |
+
[yz-zy, xx-yy-zz, 0.0, 0.0],
|
991 |
+
[zx-xz, xy+yx, yy-xx-zz, 0.0],
|
992 |
+
[xy-yx, zx+xz, yz+zy, zz-xx-yy]]
|
993 |
+
# quaternion: eigenvector corresponding to most positive eigenvalue
|
994 |
+
w, V = numpy.linalg.eigh(N)
|
995 |
+
q = V[:, numpy.argmax(w)]
|
996 |
+
q /= vector_norm(q) # unit quaternion
|
997 |
+
# homogeneous transformation matrix
|
998 |
+
M = quaternion_matrix(q)
|
999 |
+
|
1000 |
+
if scale and not shear:
|
1001 |
+
# Affine transformation; scale is ratio of RMS deviations from centroid
|
1002 |
+
v0 *= v0
|
1003 |
+
v1 *= v1
|
1004 |
+
M[:ndims, :ndims] *= math.sqrt(numpy.sum(v1) / numpy.sum(v0))
|
1005 |
+
|
1006 |
+
# move centroids back
|
1007 |
+
M = numpy.dot(numpy.linalg.inv(M1), numpy.dot(M, M0))
|
1008 |
+
M /= M[ndims, ndims]
|
1009 |
+
return M
|
1010 |
+
|
1011 |
+
|
1012 |
+
def superimposition_matrix(v0, v1, scale=False, usesvd=True):
|
1013 |
+
"""Return matrix to transform given 3D point set into second point set.
|
1014 |
+
|
1015 |
+
v0 and v1 are shape (3, \*) or (4, \*) arrays of at least 3 points.
|
1016 |
+
|
1017 |
+
The parameters scale and usesvd are explained in the more general
|
1018 |
+
affine_matrix_from_points function.
|
1019 |
+
|
1020 |
+
The returned matrix is a similarity or Euclidean transformation matrix.
|
1021 |
+
This function has a fast C implementation in transformations.c.
|
1022 |
+
|
1023 |
+
>>> v0 = numpy.random.rand(3, 10)
|
1024 |
+
>>> M = superimposition_matrix(v0, v0)
|
1025 |
+
>>> numpy.allclose(M, numpy.identity(4))
|
1026 |
+
True
|
1027 |
+
>>> R = random_rotation_matrix(numpy.random.random(3))
|
1028 |
+
>>> v0 = [[1,0,0], [0,1,0], [0,0,1], [1,1,1]]
|
1029 |
+
>>> v1 = numpy.dot(R, v0)
|
1030 |
+
>>> M = superimposition_matrix(v0, v1)
|
1031 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1032 |
+
True
|
1033 |
+
>>> v0 = (numpy.random.rand(4, 100) - 0.5) * 20
|
1034 |
+
>>> v0[3] = 1
|
1035 |
+
>>> v1 = numpy.dot(R, v0)
|
1036 |
+
>>> M = superimposition_matrix(v0, v1)
|
1037 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1038 |
+
True
|
1039 |
+
>>> S = scale_matrix(random.random())
|
1040 |
+
>>> T = translation_matrix(numpy.random.random(3)-0.5)
|
1041 |
+
>>> M = concatenate_matrices(T, R, S)
|
1042 |
+
>>> v1 = numpy.dot(M, v0)
|
1043 |
+
>>> v0[:3] += numpy.random.normal(0, 1e-9, 300).reshape(3, -1)
|
1044 |
+
>>> M = superimposition_matrix(v0, v1, scale=True)
|
1045 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1046 |
+
True
|
1047 |
+
>>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
|
1048 |
+
>>> numpy.allclose(v1, numpy.dot(M, v0))
|
1049 |
+
True
|
1050 |
+
>>> v = numpy.empty((4, 100, 3))
|
1051 |
+
>>> v[:, :, 0] = v0
|
1052 |
+
>>> M = superimposition_matrix(v0, v1, scale=True, usesvd=False)
|
1053 |
+
>>> numpy.allclose(v1, numpy.dot(M, v[:, :, 0]))
|
1054 |
+
True
|
1055 |
+
|
1056 |
+
"""
|
1057 |
+
v0 = numpy.array(v0, dtype=numpy.float64, copy=False)[:3]
|
1058 |
+
v1 = numpy.array(v1, dtype=numpy.float64, copy=False)[:3]
|
1059 |
+
return affine_matrix_from_points(v0, v1, shear=False,
|
1060 |
+
scale=scale, usesvd=usesvd)
|
1061 |
+
|
1062 |
+
|
1063 |
+
def euler_matrix(ai, aj, ak, axes='sxyz'):
|
1064 |
+
"""Return homogeneous rotation matrix from Euler angles and axis sequence.
|
1065 |
+
|
1066 |
+
ai, aj, ak : Euler's roll, pitch and yaw angles
|
1067 |
+
axes : One of 24 axis sequences as string or encoded tuple
|
1068 |
+
|
1069 |
+
>>> R = euler_matrix(1, 2, 3, 'syxz')
|
1070 |
+
>>> numpy.allclose(numpy.sum(R[0]), -1.34786452)
|
1071 |
+
True
|
1072 |
+
>>> R = euler_matrix(1, 2, 3, (0, 1, 0, 1))
|
1073 |
+
>>> numpy.allclose(numpy.sum(R[0]), -0.383436184)
|
1074 |
+
True
|
1075 |
+
>>> ai, aj, ak = (4*math.pi) * (numpy.random.random(3) - 0.5)
|
1076 |
+
>>> for axes in _AXES2TUPLE.keys():
|
1077 |
+
... R = euler_matrix(ai, aj, ak, axes)
|
1078 |
+
>>> for axes in _TUPLE2AXES.keys():
|
1079 |
+
... R = euler_matrix(ai, aj, ak, axes)
|
1080 |
+
|
1081 |
+
"""
|
1082 |
+
try:
|
1083 |
+
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes]
|
1084 |
+
except (AttributeError, KeyError):
|
1085 |
+
_TUPLE2AXES[axes] # noqa: validation
|
1086 |
+
firstaxis, parity, repetition, frame = axes
|
1087 |
+
|
1088 |
+
i = firstaxis
|
1089 |
+
j = _NEXT_AXIS[i+parity]
|
1090 |
+
k = _NEXT_AXIS[i-parity+1]
|
1091 |
+
|
1092 |
+
if frame:
|
1093 |
+
ai, ak = ak, ai
|
1094 |
+
if parity:
|
1095 |
+
ai, aj, ak = -ai, -aj, -ak
|
1096 |
+
|
1097 |
+
si, sj, sk = math.sin(ai), math.sin(aj), math.sin(ak)
|
1098 |
+
ci, cj, ck = math.cos(ai), math.cos(aj), math.cos(ak)
|
1099 |
+
cc, cs = ci*ck, ci*sk
|
1100 |
+
sc, ss = si*ck, si*sk
|
1101 |
+
|
1102 |
+
M = numpy.identity(4)
|
1103 |
+
if repetition:
|
1104 |
+
M[i, i] = cj
|
1105 |
+
M[i, j] = sj*si
|
1106 |
+
M[i, k] = sj*ci
|
1107 |
+
M[j, i] = sj*sk
|
1108 |
+
M[j, j] = -cj*ss+cc
|
1109 |
+
M[j, k] = -cj*cs-sc
|
1110 |
+
M[k, i] = -sj*ck
|
1111 |
+
M[k, j] = cj*sc+cs
|
1112 |
+
M[k, k] = cj*cc-ss
|
1113 |
+
else:
|
1114 |
+
M[i, i] = cj*ck
|
1115 |
+
M[i, j] = sj*sc-cs
|
1116 |
+
M[i, k] = sj*cc+ss
|
1117 |
+
M[j, i] = cj*sk
|
1118 |
+
M[j, j] = sj*ss+cc
|
1119 |
+
M[j, k] = sj*cs-sc
|
1120 |
+
M[k, i] = -sj
|
1121 |
+
M[k, j] = cj*si
|
1122 |
+
M[k, k] = cj*ci
|
1123 |
+
return M
|
1124 |
+
|
1125 |
+
|
1126 |
+
def euler_from_matrix(matrix, axes='sxyz'):
|
1127 |
+
"""Return Euler angles from rotation matrix for specified axis sequence.
|
1128 |
+
|
1129 |
+
axes : One of 24 axis sequences as string or encoded tuple
|
1130 |
+
|
1131 |
+
Note that many Euler angle triplets can describe one matrix.
|
1132 |
+
|
1133 |
+
>>> R0 = euler_matrix(1, 2, 3, 'syxz')
|
1134 |
+
>>> al, be, ga = euler_from_matrix(R0, 'syxz')
|
1135 |
+
>>> R1 = euler_matrix(al, be, ga, 'syxz')
|
1136 |
+
>>> numpy.allclose(R0, R1)
|
1137 |
+
True
|
1138 |
+
>>> angles = (4*math.pi) * (numpy.random.random(3) - 0.5)
|
1139 |
+
>>> for axes in _AXES2TUPLE.keys():
|
1140 |
+
... R0 = euler_matrix(axes=axes, *angles)
|
1141 |
+
... R1 = euler_matrix(axes=axes, *euler_from_matrix(R0, axes))
|
1142 |
+
... if not numpy.allclose(R0, R1): print(axes, "failed")
|
1143 |
+
|
1144 |
+
"""
|
1145 |
+
try:
|
1146 |
+
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
|
1147 |
+
except (AttributeError, KeyError):
|
1148 |
+
_TUPLE2AXES[axes] # noqa: validation
|
1149 |
+
firstaxis, parity, repetition, frame = axes
|
1150 |
+
|
1151 |
+
i = firstaxis
|
1152 |
+
j = _NEXT_AXIS[i+parity]
|
1153 |
+
k = _NEXT_AXIS[i-parity+1]
|
1154 |
+
|
1155 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:3, :3]
|
1156 |
+
if repetition:
|
1157 |
+
sy = math.sqrt(M[i, j]*M[i, j] + M[i, k]*M[i, k])
|
1158 |
+
if sy > _EPS:
|
1159 |
+
ax = math.atan2(M[i, j], M[i, k])
|
1160 |
+
ay = math.atan2(sy, M[i, i])
|
1161 |
+
az = math.atan2(M[j, i], -M[k, i])
|
1162 |
+
else:
|
1163 |
+
ax = math.atan2(-M[j, k], M[j, j])
|
1164 |
+
ay = math.atan2(sy, M[i, i])
|
1165 |
+
az = 0.0
|
1166 |
+
else:
|
1167 |
+
cy = math.sqrt(M[i, i]*M[i, i] + M[j, i]*M[j, i])
|
1168 |
+
if cy > _EPS:
|
1169 |
+
ax = math.atan2(M[k, j], M[k, k])
|
1170 |
+
ay = math.atan2(-M[k, i], cy)
|
1171 |
+
az = math.atan2(M[j, i], M[i, i])
|
1172 |
+
else:
|
1173 |
+
ax = math.atan2(-M[j, k], M[j, j])
|
1174 |
+
ay = math.atan2(-M[k, i], cy)
|
1175 |
+
az = 0.0
|
1176 |
+
|
1177 |
+
if parity:
|
1178 |
+
ax, ay, az = -ax, -ay, -az
|
1179 |
+
if frame:
|
1180 |
+
ax, az = az, ax
|
1181 |
+
return ax, ay, az
|
1182 |
+
|
1183 |
+
|
1184 |
+
def euler_from_quaternion(quaternion, axes='sxyz'):
|
1185 |
+
"""Return Euler angles from quaternion for specified axis sequence.
|
1186 |
+
|
1187 |
+
>>> angles = euler_from_quaternion([0.99810947, 0.06146124, 0, 0])
|
1188 |
+
>>> numpy.allclose(angles, [0.123, 0, 0])
|
1189 |
+
True
|
1190 |
+
|
1191 |
+
"""
|
1192 |
+
return euler_from_matrix(quaternion_matrix(quaternion), axes)
|
1193 |
+
|
1194 |
+
|
1195 |
+
def quaternion_from_euler(ai, aj, ak, axes='sxyz'):
|
1196 |
+
"""Return quaternion from Euler angles and axis sequence.
|
1197 |
+
|
1198 |
+
ai, aj, ak : Euler's roll, pitch and yaw angles
|
1199 |
+
axes : One of 24 axis sequences as string or encoded tuple
|
1200 |
+
|
1201 |
+
>>> q = quaternion_from_euler(1, 2, 3, 'ryxz')
|
1202 |
+
>>> numpy.allclose(q, [0.435953, 0.310622, -0.718287, 0.444435])
|
1203 |
+
True
|
1204 |
+
|
1205 |
+
"""
|
1206 |
+
try:
|
1207 |
+
firstaxis, parity, repetition, frame = _AXES2TUPLE[axes.lower()]
|
1208 |
+
except (AttributeError, KeyError):
|
1209 |
+
_TUPLE2AXES[axes] # noqa: validation
|
1210 |
+
firstaxis, parity, repetition, frame = axes
|
1211 |
+
|
1212 |
+
i = firstaxis + 1
|
1213 |
+
j = _NEXT_AXIS[i+parity-1] + 1
|
1214 |
+
k = _NEXT_AXIS[i-parity] + 1
|
1215 |
+
|
1216 |
+
if frame:
|
1217 |
+
ai, ak = ak, ai
|
1218 |
+
if parity:
|
1219 |
+
aj = -aj
|
1220 |
+
|
1221 |
+
ai /= 2.0
|
1222 |
+
aj /= 2.0
|
1223 |
+
ak /= 2.0
|
1224 |
+
ci = math.cos(ai)
|
1225 |
+
si = math.sin(ai)
|
1226 |
+
cj = math.cos(aj)
|
1227 |
+
sj = math.sin(aj)
|
1228 |
+
ck = math.cos(ak)
|
1229 |
+
sk = math.sin(ak)
|
1230 |
+
cc = ci*ck
|
1231 |
+
cs = ci*sk
|
1232 |
+
sc = si*ck
|
1233 |
+
ss = si*sk
|
1234 |
+
|
1235 |
+
q = numpy.empty((4, ))
|
1236 |
+
if repetition:
|
1237 |
+
q[0] = cj*(cc - ss)
|
1238 |
+
q[i] = cj*(cs + sc)
|
1239 |
+
q[j] = sj*(cc + ss)
|
1240 |
+
q[k] = sj*(cs - sc)
|
1241 |
+
else:
|
1242 |
+
q[0] = cj*cc + sj*ss
|
1243 |
+
q[i] = cj*sc - sj*cs
|
1244 |
+
q[j] = cj*ss + sj*cc
|
1245 |
+
q[k] = cj*cs - sj*sc
|
1246 |
+
if parity:
|
1247 |
+
q[j] *= -1.0
|
1248 |
+
|
1249 |
+
return q
|
1250 |
+
|
1251 |
+
|
1252 |
+
def quaternion_about_axis(angle, axis):
|
1253 |
+
"""Return quaternion for rotation about axis.
|
1254 |
+
|
1255 |
+
>>> q = quaternion_about_axis(0.123, [1, 0, 0])
|
1256 |
+
>>> numpy.allclose(q, [0.99810947, 0.06146124, 0, 0])
|
1257 |
+
True
|
1258 |
+
|
1259 |
+
"""
|
1260 |
+
q = numpy.array([0.0, axis[0], axis[1], axis[2]])
|
1261 |
+
qlen = vector_norm(q)
|
1262 |
+
if qlen > _EPS:
|
1263 |
+
q *= math.sin(angle/2.0) / qlen
|
1264 |
+
q[0] = math.cos(angle/2.0)
|
1265 |
+
return q
|
1266 |
+
|
1267 |
+
|
1268 |
+
def quaternion_matrix(quaternion):
|
1269 |
+
"""Return homogeneous rotation matrix from quaternion.
|
1270 |
+
|
1271 |
+
>>> M = quaternion_matrix([0.99810947, 0.06146124, 0, 0])
|
1272 |
+
>>> numpy.allclose(M, rotation_matrix(0.123, [1, 0, 0]))
|
1273 |
+
True
|
1274 |
+
>>> M = quaternion_matrix([1, 0, 0, 0])
|
1275 |
+
>>> numpy.allclose(M, numpy.identity(4))
|
1276 |
+
True
|
1277 |
+
>>> M = quaternion_matrix([0, 1, 0, 0])
|
1278 |
+
>>> numpy.allclose(M, numpy.diag([1, -1, -1, 1]))
|
1279 |
+
True
|
1280 |
+
|
1281 |
+
"""
|
1282 |
+
q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
|
1283 |
+
n = numpy.dot(q, q)
|
1284 |
+
if n < _EPS:
|
1285 |
+
return numpy.identity(4)
|
1286 |
+
q *= math.sqrt(2.0 / n)
|
1287 |
+
q = numpy.outer(q, q)
|
1288 |
+
return numpy.array([
|
1289 |
+
[1.0-q[2, 2]-q[3, 3], q[1, 2]-q[3, 0], q[1, 3]+q[2, 0], 0.0],
|
1290 |
+
[q[1, 2]+q[3, 0], 1.0-q[1, 1]-q[3, 3], q[2, 3]-q[1, 0], 0.0],
|
1291 |
+
[q[1, 3]-q[2, 0], q[2, 3]+q[1, 0], 1.0-q[1, 1]-q[2, 2], 0.0],
|
1292 |
+
[0.0, 0.0, 0.0, 1.0]])
|
1293 |
+
|
1294 |
+
|
1295 |
+
def quaternion_from_matrix(matrix, isprecise=False):
|
1296 |
+
"""Return quaternion from rotation matrix.
|
1297 |
+
|
1298 |
+
If isprecise is True, the input matrix is assumed to be a precise rotation
|
1299 |
+
matrix and a faster algorithm is used.
|
1300 |
+
|
1301 |
+
>>> q = quaternion_from_matrix(numpy.identity(4), True)
|
1302 |
+
>>> numpy.allclose(q, [1, 0, 0, 0])
|
1303 |
+
True
|
1304 |
+
>>> q = quaternion_from_matrix(numpy.diag([1, -1, -1, 1]))
|
1305 |
+
>>> numpy.allclose(q, [0, 1, 0, 0]) or numpy.allclose(q, [0, -1, 0, 0])
|
1306 |
+
True
|
1307 |
+
>>> R = rotation_matrix(0.123, (1, 2, 3))
|
1308 |
+
>>> q = quaternion_from_matrix(R, True)
|
1309 |
+
>>> numpy.allclose(q, [0.9981095, 0.0164262, 0.0328524, 0.0492786])
|
1310 |
+
True
|
1311 |
+
>>> R = [[-0.545, 0.797, 0.260, 0], [0.733, 0.603, -0.313, 0],
|
1312 |
+
... [-0.407, 0.021, -0.913, 0], [0, 0, 0, 1]]
|
1313 |
+
>>> q = quaternion_from_matrix(R)
|
1314 |
+
>>> numpy.allclose(q, [0.19069, 0.43736, 0.87485, -0.083611])
|
1315 |
+
True
|
1316 |
+
>>> R = [[0.395, 0.362, 0.843, 0], [-0.626, 0.796, -0.056, 0],
|
1317 |
+
... [-0.677, -0.498, 0.529, 0], [0, 0, 0, 1]]
|
1318 |
+
>>> q = quaternion_from_matrix(R)
|
1319 |
+
>>> numpy.allclose(q, [0.82336615, -0.13610694, 0.46344705, -0.29792603])
|
1320 |
+
True
|
1321 |
+
>>> R = random_rotation_matrix()
|
1322 |
+
>>> q = quaternion_from_matrix(R)
|
1323 |
+
>>> is_same_transform(R, quaternion_matrix(q))
|
1324 |
+
True
|
1325 |
+
>>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
|
1326 |
+
... quaternion_from_matrix(R, isprecise=True))
|
1327 |
+
True
|
1328 |
+
>>> R = euler_matrix(0.0, 0.0, numpy.pi/2.0)
|
1329 |
+
>>> is_same_quaternion(quaternion_from_matrix(R, isprecise=False),
|
1330 |
+
... quaternion_from_matrix(R, isprecise=True))
|
1331 |
+
True
|
1332 |
+
|
1333 |
+
"""
|
1334 |
+
M = numpy.array(matrix, dtype=numpy.float64, copy=False)[:4, :4]
|
1335 |
+
if isprecise:
|
1336 |
+
q = numpy.empty((4, ))
|
1337 |
+
t = numpy.trace(M)
|
1338 |
+
if t > M[3, 3]:
|
1339 |
+
q[0] = t
|
1340 |
+
q[3] = M[1, 0] - M[0, 1]
|
1341 |
+
q[2] = M[0, 2] - M[2, 0]
|
1342 |
+
q[1] = M[2, 1] - M[1, 2]
|
1343 |
+
else:
|
1344 |
+
i, j, k = 0, 1, 2
|
1345 |
+
if M[1, 1] > M[0, 0]:
|
1346 |
+
i, j, k = 1, 2, 0
|
1347 |
+
if M[2, 2] > M[i, i]:
|
1348 |
+
i, j, k = 2, 0, 1
|
1349 |
+
t = M[i, i] - (M[j, j] + M[k, k]) + M[3, 3]
|
1350 |
+
q[i] = t
|
1351 |
+
q[j] = M[i, j] + M[j, i]
|
1352 |
+
q[k] = M[k, i] + M[i, k]
|
1353 |
+
q[3] = M[k, j] - M[j, k]
|
1354 |
+
q = q[[3, 0, 1, 2]]
|
1355 |
+
q *= 0.5 / math.sqrt(t * M[3, 3])
|
1356 |
+
else:
|
1357 |
+
m00 = M[0, 0]
|
1358 |
+
m01 = M[0, 1]
|
1359 |
+
m02 = M[0, 2]
|
1360 |
+
m10 = M[1, 0]
|
1361 |
+
m11 = M[1, 1]
|
1362 |
+
m12 = M[1, 2]
|
1363 |
+
m20 = M[2, 0]
|
1364 |
+
m21 = M[2, 1]
|
1365 |
+
m22 = M[2, 2]
|
1366 |
+
# symmetric matrix K
|
1367 |
+
K = numpy.array([[m00-m11-m22, 0.0, 0.0, 0.0],
|
1368 |
+
[m01+m10, m11-m00-m22, 0.0, 0.0],
|
1369 |
+
[m02+m20, m12+m21, m22-m00-m11, 0.0],
|
1370 |
+
[m21-m12, m02-m20, m10-m01, m00+m11+m22]])
|
1371 |
+
K /= 3.0
|
1372 |
+
# quaternion is eigenvector of K that corresponds to largest eigenvalue
|
1373 |
+
w, V = numpy.linalg.eigh(K)
|
1374 |
+
q = V[[3, 0, 1, 2], numpy.argmax(w)]
|
1375 |
+
if q[0] < 0.0:
|
1376 |
+
numpy.negative(q, q)
|
1377 |
+
return q
|
1378 |
+
|
1379 |
+
|
1380 |
+
def quaternion_multiply(quaternion1, quaternion0):
|
1381 |
+
"""Return multiplication of two quaternions.
|
1382 |
+
|
1383 |
+
>>> q = quaternion_multiply([4, 1, -2, 3], [8, -5, 6, 7])
|
1384 |
+
>>> numpy.allclose(q, [28, -44, -14, 48])
|
1385 |
+
True
|
1386 |
+
|
1387 |
+
"""
|
1388 |
+
w0, x0, y0, z0 = quaternion0
|
1389 |
+
w1, x1, y1, z1 = quaternion1
|
1390 |
+
return numpy.array([
|
1391 |
+
-x1*x0 - y1*y0 - z1*z0 + w1*w0,
|
1392 |
+
x1*w0 + y1*z0 - z1*y0 + w1*x0,
|
1393 |
+
-x1*z0 + y1*w0 + z1*x0 + w1*y0,
|
1394 |
+
x1*y0 - y1*x0 + z1*w0 + w1*z0], dtype=numpy.float64)
|
1395 |
+
|
1396 |
+
|
1397 |
+
def quaternion_conjugate(quaternion):
|
1398 |
+
"""Return conjugate of quaternion.
|
1399 |
+
|
1400 |
+
>>> q0 = random_quaternion()
|
1401 |
+
>>> q1 = quaternion_conjugate(q0)
|
1402 |
+
>>> q1[0] == q0[0] and all(q1[1:] == -q0[1:])
|
1403 |
+
True
|
1404 |
+
|
1405 |
+
"""
|
1406 |
+
q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
|
1407 |
+
numpy.negative(q[1:], q[1:])
|
1408 |
+
return q
|
1409 |
+
|
1410 |
+
|
1411 |
+
def quaternion_inverse(quaternion):
|
1412 |
+
"""Return inverse of quaternion.
|
1413 |
+
|
1414 |
+
>>> q0 = random_quaternion()
|
1415 |
+
>>> q1 = quaternion_inverse(q0)
|
1416 |
+
>>> numpy.allclose(quaternion_multiply(q0, q1), [1, 0, 0, 0])
|
1417 |
+
True
|
1418 |
+
|
1419 |
+
"""
|
1420 |
+
q = numpy.array(quaternion, dtype=numpy.float64, copy=True)
|
1421 |
+
numpy.negative(q[1:], q[1:])
|
1422 |
+
return q / numpy.dot(q, q)
|
1423 |
+
|
1424 |
+
|
1425 |
+
def quaternion_real(quaternion):
|
1426 |
+
"""Return real part of quaternion.
|
1427 |
+
|
1428 |
+
>>> quaternion_real([3, 0, 1, 2])
|
1429 |
+
3.0
|
1430 |
+
|
1431 |
+
"""
|
1432 |
+
return float(quaternion[0])
|
1433 |
+
|
1434 |
+
|
1435 |
+
def quaternion_imag(quaternion):
|
1436 |
+
"""Return imaginary part of quaternion.
|
1437 |
+
|
1438 |
+
>>> quaternion_imag([3, 0, 1, 2])
|
1439 |
+
array([ 0., 1., 2.])
|
1440 |
+
|
1441 |
+
"""
|
1442 |
+
return numpy.array(quaternion[1:4], dtype=numpy.float64, copy=True)
|
1443 |
+
|
1444 |
+
|
1445 |
+
def quaternion_slerp(quat0, quat1, fraction, spin=0, shortestpath=True):
|
1446 |
+
"""Return spherical linear interpolation between two quaternions.
|
1447 |
+
|
1448 |
+
>>> q0 = random_quaternion()
|
1449 |
+
>>> q1 = random_quaternion()
|
1450 |
+
>>> q = quaternion_slerp(q0, q1, 0)
|
1451 |
+
>>> numpy.allclose(q, q0)
|
1452 |
+
True
|
1453 |
+
>>> q = quaternion_slerp(q0, q1, 1, 1)
|
1454 |
+
>>> numpy.allclose(q, q1)
|
1455 |
+
True
|
1456 |
+
>>> q = quaternion_slerp(q0, q1, 0.5)
|
1457 |
+
>>> angle = math.acos(numpy.dot(q0, q))
|
1458 |
+
>>> numpy.allclose(2, math.acos(numpy.dot(q0, q1)) / angle) or \
|
1459 |
+
numpy.allclose(2, math.acos(-numpy.dot(q0, q1)) / angle)
|
1460 |
+
True
|
1461 |
+
|
1462 |
+
"""
|
1463 |
+
q0 = unit_vector(quat0[:4])
|
1464 |
+
q1 = unit_vector(quat1[:4])
|
1465 |
+
if fraction == 0.0:
|
1466 |
+
return q0
|
1467 |
+
elif fraction == 1.0:
|
1468 |
+
return q1
|
1469 |
+
d = numpy.dot(q0, q1)
|
1470 |
+
if abs(abs(d) - 1.0) < _EPS:
|
1471 |
+
return q0
|
1472 |
+
if shortestpath and d < 0.0:
|
1473 |
+
# invert rotation
|
1474 |
+
d = -d
|
1475 |
+
numpy.negative(q1, q1)
|
1476 |
+
angle = math.acos(d) + spin * math.pi
|
1477 |
+
if abs(angle) < _EPS:
|
1478 |
+
return q0
|
1479 |
+
isin = 1.0 / math.sin(angle)
|
1480 |
+
q0 *= math.sin((1.0 - fraction) * angle) * isin
|
1481 |
+
q1 *= math.sin(fraction * angle) * isin
|
1482 |
+
q0 += q1
|
1483 |
+
return q0
|
1484 |
+
|
1485 |
+
|
1486 |
+
def random_quaternion(rand=None):
|
1487 |
+
"""Return uniform random unit quaternion.
|
1488 |
+
|
1489 |
+
rand: array like or None
|
1490 |
+
Three independent random variables that are uniformly distributed
|
1491 |
+
between 0 and 1.
|
1492 |
+
|
1493 |
+
>>> q = random_quaternion()
|
1494 |
+
>>> numpy.allclose(1, vector_norm(q))
|
1495 |
+
True
|
1496 |
+
>>> q = random_quaternion(numpy.random.random(3))
|
1497 |
+
>>> len(q.shape), q.shape[0]==4
|
1498 |
+
(1, True)
|
1499 |
+
|
1500 |
+
"""
|
1501 |
+
if rand is None:
|
1502 |
+
rand = numpy.random.rand(3)
|
1503 |
+
else:
|
1504 |
+
assert len(rand) == 3
|
1505 |
+
r1 = numpy.sqrt(1.0 - rand[0])
|
1506 |
+
r2 = numpy.sqrt(rand[0])
|
1507 |
+
pi2 = math.pi * 2.0
|
1508 |
+
t1 = pi2 * rand[1]
|
1509 |
+
t2 = pi2 * rand[2]
|
1510 |
+
return numpy.array([numpy.cos(t2)*r2, numpy.sin(t1)*r1,
|
1511 |
+
numpy.cos(t1)*r1, numpy.sin(t2)*r2])
|
1512 |
+
|
1513 |
+
|
1514 |
+
def random_rotation_matrix(rand=None):
|
1515 |
+
"""Return uniform random rotation matrix.
|
1516 |
+
|
1517 |
+
rand: array like
|
1518 |
+
Three independent random variables that are uniformly distributed
|
1519 |
+
between 0 and 1 for each returned quaternion.
|
1520 |
+
|
1521 |
+
>>> R = random_rotation_matrix()
|
1522 |
+
>>> numpy.allclose(numpy.dot(R.T, R), numpy.identity(4))
|
1523 |
+
True
|
1524 |
+
|
1525 |
+
"""
|
1526 |
+
return quaternion_matrix(random_quaternion(rand))
|
1527 |
+
|
1528 |
+
|
1529 |
+
class Arcball(object):
|
1530 |
+
"""Virtual Trackball Control.
|
1531 |
+
|
1532 |
+
>>> ball = Arcball()
|
1533 |
+
>>> ball = Arcball(initial=numpy.identity(4))
|
1534 |
+
>>> ball.place([320, 320], 320)
|
1535 |
+
>>> ball.down([500, 250])
|
1536 |
+
>>> ball.drag([475, 275])
|
1537 |
+
>>> R = ball.matrix()
|
1538 |
+
>>> numpy.allclose(numpy.sum(R), 3.90583455)
|
1539 |
+
True
|
1540 |
+
>>> ball = Arcball(initial=[1, 0, 0, 0])
|
1541 |
+
>>> ball.place([320, 320], 320)
|
1542 |
+
>>> ball.setaxes([1, 1, 0], [-1, 1, 0])
|
1543 |
+
>>> ball.constrain = True
|
1544 |
+
>>> ball.down([400, 200])
|
1545 |
+
>>> ball.drag([200, 400])
|
1546 |
+
>>> R = ball.matrix()
|
1547 |
+
>>> numpy.allclose(numpy.sum(R), 0.2055924)
|
1548 |
+
True
|
1549 |
+
>>> ball.next()
|
1550 |
+
|
1551 |
+
"""
|
1552 |
+
|
1553 |
+
def __init__(self, initial=None):
|
1554 |
+
"""Initialize virtual trackball control.
|
1555 |
+
|
1556 |
+
initial : quaternion or rotation matrix
|
1557 |
+
|
1558 |
+
"""
|
1559 |
+
self._axis = None
|
1560 |
+
self._axes = None
|
1561 |
+
self._radius = 1.0
|
1562 |
+
self._center = [0.0, 0.0]
|
1563 |
+
self._vdown = numpy.array([0.0, 0.0, 1.0])
|
1564 |
+
self._constrain = False
|
1565 |
+
if initial is None:
|
1566 |
+
self._qdown = numpy.array([1.0, 0.0, 0.0, 0.0])
|
1567 |
+
else:
|
1568 |
+
initial = numpy.array(initial, dtype=numpy.float64)
|
1569 |
+
if initial.shape == (4, 4):
|
1570 |
+
self._qdown = quaternion_from_matrix(initial)
|
1571 |
+
elif initial.shape == (4, ):
|
1572 |
+
initial /= vector_norm(initial)
|
1573 |
+
self._qdown = initial
|
1574 |
+
else:
|
1575 |
+
raise ValueError("initial not a quaternion or matrix")
|
1576 |
+
self._qnow = self._qpre = self._qdown
|
1577 |
+
|
1578 |
+
def place(self, center, radius):
|
1579 |
+
"""Place Arcball, e.g. when window size changes.
|
1580 |
+
|
1581 |
+
center : sequence[2]
|
1582 |
+
Window coordinates of trackball center.
|
1583 |
+
radius : float
|
1584 |
+
Radius of trackball in window coordinates.
|
1585 |
+
|
1586 |
+
"""
|
1587 |
+
self._radius = float(radius)
|
1588 |
+
self._center[0] = center[0]
|
1589 |
+
self._center[1] = center[1]
|
1590 |
+
|
1591 |
+
def setaxes(self, *axes):
|
1592 |
+
"""Set axes to constrain rotations."""
|
1593 |
+
if axes is None:
|
1594 |
+
self._axes = None
|
1595 |
+
else:
|
1596 |
+
self._axes = [unit_vector(axis) for axis in axes]
|
1597 |
+
|
1598 |
+
@property
|
1599 |
+
def constrain(self):
|
1600 |
+
"""Return state of constrain to axis mode."""
|
1601 |
+
return self._constrain
|
1602 |
+
|
1603 |
+
@constrain.setter
|
1604 |
+
def constrain(self, value):
|
1605 |
+
"""Set state of constrain to axis mode."""
|
1606 |
+
self._constrain = bool(value)
|
1607 |
+
|
1608 |
+
def down(self, point):
|
1609 |
+
"""Set initial cursor window coordinates and pick constrain-axis."""
|
1610 |
+
self._vdown = arcball_map_to_sphere(point, self._center, self._radius)
|
1611 |
+
self._qdown = self._qpre = self._qnow
|
1612 |
+
if self._constrain and self._axes is not None:
|
1613 |
+
self._axis = arcball_nearest_axis(self._vdown, self._axes)
|
1614 |
+
self._vdown = arcball_constrain_to_axis(self._vdown, self._axis)
|
1615 |
+
else:
|
1616 |
+
self._axis = None
|
1617 |
+
|
1618 |
+
def drag(self, point):
|
1619 |
+
"""Update current cursor window coordinates."""
|
1620 |
+
vnow = arcball_map_to_sphere(point, self._center, self._radius)
|
1621 |
+
if self._axis is not None:
|
1622 |
+
vnow = arcball_constrain_to_axis(vnow, self._axis)
|
1623 |
+
self._qpre = self._qnow
|
1624 |
+
t = numpy.cross(self._vdown, vnow)
|
1625 |
+
if numpy.dot(t, t) < _EPS:
|
1626 |
+
self._qnow = self._qdown
|
1627 |
+
else:
|
1628 |
+
q = [numpy.dot(self._vdown, vnow), t[0], t[1], t[2]]
|
1629 |
+
self._qnow = quaternion_multiply(q, self._qdown)
|
1630 |
+
|
1631 |
+
def next(self, acceleration=0.0):
|
1632 |
+
"""Continue rotation in direction of last drag."""
|
1633 |
+
q = quaternion_slerp(self._qpre, self._qnow, 2.0+acceleration, False)
|
1634 |
+
self._qpre, self._qnow = self._qnow, q
|
1635 |
+
|
1636 |
+
def matrix(self):
|
1637 |
+
"""Return homogeneous rotation matrix."""
|
1638 |
+
return quaternion_matrix(self._qnow)
|
1639 |
+
|
1640 |
+
|
1641 |
+
def arcball_map_to_sphere(point, center, radius):
|
1642 |
+
"""Return unit sphere coordinates from window coordinates."""
|
1643 |
+
v0 = (point[0] - center[0]) / radius
|
1644 |
+
v1 = (center[1] - point[1]) / radius
|
1645 |
+
n = v0*v0 + v1*v1
|
1646 |
+
if n > 1.0:
|
1647 |
+
# position outside of sphere
|
1648 |
+
n = math.sqrt(n)
|
1649 |
+
return numpy.array([v0/n, v1/n, 0.0])
|
1650 |
+
else:
|
1651 |
+
return numpy.array([v0, v1, math.sqrt(1.0 - n)])
|
1652 |
+
|
1653 |
+
|
1654 |
+
def arcball_constrain_to_axis(point, axis):
|
1655 |
+
"""Return sphere point perpendicular to axis."""
|
1656 |
+
v = numpy.array(point, dtype=numpy.float64, copy=True)
|
1657 |
+
a = numpy.array(axis, dtype=numpy.float64, copy=True)
|
1658 |
+
v -= a * numpy.dot(a, v) # on plane
|
1659 |
+
n = vector_norm(v)
|
1660 |
+
if n > _EPS:
|
1661 |
+
if v[2] < 0.0:
|
1662 |
+
numpy.negative(v, v)
|
1663 |
+
v /= n
|
1664 |
+
return v
|
1665 |
+
if a[2] == 1.0:
|
1666 |
+
return numpy.array([1.0, 0.0, 0.0])
|
1667 |
+
return unit_vector([-a[1], a[0], 0.0])
|
1668 |
+
|
1669 |
+
|
1670 |
+
def arcball_nearest_axis(point, axes):
|
1671 |
+
"""Return axis, which arc is nearest to point."""
|
1672 |
+
point = numpy.array(point, dtype=numpy.float64, copy=False)
|
1673 |
+
nearest = None
|
1674 |
+
mx = -1.0
|
1675 |
+
for axis in axes:
|
1676 |
+
t = numpy.dot(arcball_constrain_to_axis(point, axis), point)
|
1677 |
+
if t > mx:
|
1678 |
+
nearest = axis
|
1679 |
+
mx = t
|
1680 |
+
return nearest
|
1681 |
+
|
1682 |
+
|
1683 |
+
# epsilon for testing whether a number is close to zero
|
1684 |
+
_EPS = numpy.finfo(float).eps * 4.0
|
1685 |
+
|
1686 |
+
# axis sequences for Euler angles
|
1687 |
+
_NEXT_AXIS = [1, 2, 0, 1]
|
1688 |
+
|
1689 |
+
# map axes strings to/from tuples of inner axis, parity, repetition, frame
|
1690 |
+
_AXES2TUPLE = {
|
1691 |
+
'sxyz': (0, 0, 0, 0), 'sxyx': (0, 0, 1, 0), 'sxzy': (0, 1, 0, 0),
|
1692 |
+
'sxzx': (0, 1, 1, 0), 'syzx': (1, 0, 0, 0), 'syzy': (1, 0, 1, 0),
|
1693 |
+
'syxz': (1, 1, 0, 0), 'syxy': (1, 1, 1, 0), 'szxy': (2, 0, 0, 0),
|
1694 |
+
'szxz': (2, 0, 1, 0), 'szyx': (2, 1, 0, 0), 'szyz': (2, 1, 1, 0),
|
1695 |
+
'rzyx': (0, 0, 0, 1), 'rxyx': (0, 0, 1, 1), 'ryzx': (0, 1, 0, 1),
|
1696 |
+
'rxzx': (0, 1, 1, 1), 'rxzy': (1, 0, 0, 1), 'ryzy': (1, 0, 1, 1),
|
1697 |
+
'rzxy': (1, 1, 0, 1), 'ryxy': (1, 1, 1, 1), 'ryxz': (2, 0, 0, 1),
|
1698 |
+
'rzxz': (2, 0, 1, 1), 'rxyz': (2, 1, 0, 1), 'rzyz': (2, 1, 1, 1)}
|
1699 |
+
|
1700 |
+
_TUPLE2AXES = dict((v, k) for k, v in _AXES2TUPLE.items())
|
1701 |
+
|
1702 |
+
|
1703 |
+
def vector_norm(data, axis=None, out=None):
|
1704 |
+
"""Return length, i.e. Euclidean norm, of ndarray along axis.
|
1705 |
+
|
1706 |
+
>>> v = numpy.random.random(3)
|
1707 |
+
>>> n = vector_norm(v)
|
1708 |
+
>>> numpy.allclose(n, numpy.linalg.norm(v))
|
1709 |
+
True
|
1710 |
+
>>> v = numpy.random.rand(6, 5, 3)
|
1711 |
+
>>> n = vector_norm(v, axis=-1)
|
1712 |
+
>>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=2)))
|
1713 |
+
True
|
1714 |
+
>>> n = vector_norm(v, axis=1)
|
1715 |
+
>>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
|
1716 |
+
True
|
1717 |
+
>>> v = numpy.random.rand(5, 4, 3)
|
1718 |
+
>>> n = numpy.empty((5, 3))
|
1719 |
+
>>> vector_norm(v, axis=1, out=n)
|
1720 |
+
>>> numpy.allclose(n, numpy.sqrt(numpy.sum(v*v, axis=1)))
|
1721 |
+
True
|
1722 |
+
>>> vector_norm([])
|
1723 |
+
0.0
|
1724 |
+
>>> vector_norm([1])
|
1725 |
+
1.0
|
1726 |
+
|
1727 |
+
"""
|
1728 |
+
data = numpy.array(data, dtype=numpy.float64, copy=True)
|
1729 |
+
if out is None:
|
1730 |
+
if data.ndim == 1:
|
1731 |
+
return math.sqrt(numpy.dot(data, data))
|
1732 |
+
data *= data
|
1733 |
+
out = numpy.atleast_1d(numpy.sum(data, axis=axis))
|
1734 |
+
numpy.sqrt(out, out)
|
1735 |
+
return out
|
1736 |
+
else:
|
1737 |
+
data *= data
|
1738 |
+
numpy.sum(data, axis=axis, out=out)
|
1739 |
+
numpy.sqrt(out, out)
|
1740 |
+
|
1741 |
+
|
1742 |
+
def unit_vector(data, axis=None, out=None):
|
1743 |
+
"""Return ndarray normalized by length, i.e. Euclidean norm, along axis.
|
1744 |
+
|
1745 |
+
>>> v0 = numpy.random.random(3)
|
1746 |
+
>>> v1 = unit_vector(v0)
|
1747 |
+
>>> numpy.allclose(v1, v0 / numpy.linalg.norm(v0))
|
1748 |
+
True
|
1749 |
+
>>> v0 = numpy.random.rand(5, 4, 3)
|
1750 |
+
>>> v1 = unit_vector(v0, axis=-1)
|
1751 |
+
>>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=2)), 2)
|
1752 |
+
>>> numpy.allclose(v1, v2)
|
1753 |
+
True
|
1754 |
+
>>> v1 = unit_vector(v0, axis=1)
|
1755 |
+
>>> v2 = v0 / numpy.expand_dims(numpy.sqrt(numpy.sum(v0*v0, axis=1)), 1)
|
1756 |
+
>>> numpy.allclose(v1, v2)
|
1757 |
+
True
|
1758 |
+
>>> v1 = numpy.empty((5, 4, 3))
|
1759 |
+
>>> unit_vector(v0, axis=1, out=v1)
|
1760 |
+
>>> numpy.allclose(v1, v2)
|
1761 |
+
True
|
1762 |
+
>>> list(unit_vector([]))
|
1763 |
+
[]
|
1764 |
+
>>> list(unit_vector([1]))
|
1765 |
+
[1.0]
|
1766 |
+
|
1767 |
+
"""
|
1768 |
+
if out is None:
|
1769 |
+
data = numpy.array(data, dtype=numpy.float64, copy=True)
|
1770 |
+
if data.ndim == 1:
|
1771 |
+
data /= math.sqrt(numpy.dot(data, data))
|
1772 |
+
return data
|
1773 |
+
else:
|
1774 |
+
if out is not data:
|
1775 |
+
out[:] = numpy.array(data, copy=False)
|
1776 |
+
data = out
|
1777 |
+
length = numpy.atleast_1d(numpy.sum(data*data, axis))
|
1778 |
+
numpy.sqrt(length, length)
|
1779 |
+
if axis is not None:
|
1780 |
+
length = numpy.expand_dims(length, axis)
|
1781 |
+
data /= length
|
1782 |
+
if out is None:
|
1783 |
+
return data
|
1784 |
+
|
1785 |
+
|
1786 |
+
def random_vector(size):
|
1787 |
+
"""Return array of random doubles in the half-open interval [0.0, 1.0).
|
1788 |
+
|
1789 |
+
>>> v = random_vector(10000)
|
1790 |
+
>>> numpy.all(v >= 0) and numpy.all(v < 1)
|
1791 |
+
True
|
1792 |
+
>>> v0 = random_vector(10)
|
1793 |
+
>>> v1 = random_vector(10)
|
1794 |
+
>>> numpy.any(v0 == v1)
|
1795 |
+
False
|
1796 |
+
|
1797 |
+
"""
|
1798 |
+
return numpy.random.random(size)
|
1799 |
+
|
1800 |
+
|
1801 |
+
def vector_product(v0, v1, axis=0):
|
1802 |
+
"""Return vector perpendicular to vectors.
|
1803 |
+
|
1804 |
+
>>> v = vector_product([2, 0, 0], [0, 3, 0])
|
1805 |
+
>>> numpy.allclose(v, [0, 0, 6])
|
1806 |
+
True
|
1807 |
+
>>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
|
1808 |
+
>>> v1 = [[3], [0], [0]]
|
1809 |
+
>>> v = vector_product(v0, v1)
|
1810 |
+
>>> numpy.allclose(v, [[0, 0, 0, 0], [0, 0, 6, 6], [0, -6, 0, -6]])
|
1811 |
+
True
|
1812 |
+
>>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
|
1813 |
+
>>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
|
1814 |
+
>>> v = vector_product(v0, v1, axis=1)
|
1815 |
+
>>> numpy.allclose(v, [[0, 0, 6], [0, -6, 0], [6, 0, 0], [0, -6, 6]])
|
1816 |
+
True
|
1817 |
+
|
1818 |
+
"""
|
1819 |
+
return numpy.cross(v0, v1, axis=axis)
|
1820 |
+
|
1821 |
+
|
1822 |
+
def angle_between_vectors(v0, v1, directed=True, axis=0):
|
1823 |
+
"""Return angle between vectors.
|
1824 |
+
|
1825 |
+
If directed is False, the input vectors are interpreted as undirected axes,
|
1826 |
+
i.e. the maximum angle is pi/2.
|
1827 |
+
|
1828 |
+
>>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3])
|
1829 |
+
>>> numpy.allclose(a, math.pi)
|
1830 |
+
True
|
1831 |
+
>>> a = angle_between_vectors([1, -2, 3], [-1, 2, -3], directed=False)
|
1832 |
+
>>> numpy.allclose(a, 0)
|
1833 |
+
True
|
1834 |
+
>>> v0 = [[2, 0, 0, 2], [0, 2, 0, 2], [0, 0, 2, 2]]
|
1835 |
+
>>> v1 = [[3], [0], [0]]
|
1836 |
+
>>> a = angle_between_vectors(v0, v1)
|
1837 |
+
>>> numpy.allclose(a, [0, 1.5708, 1.5708, 0.95532])
|
1838 |
+
True
|
1839 |
+
>>> v0 = [[2, 0, 0], [2, 0, 0], [0, 2, 0], [2, 0, 0]]
|
1840 |
+
>>> v1 = [[0, 3, 0], [0, 0, 3], [0, 0, 3], [3, 3, 3]]
|
1841 |
+
>>> a = angle_between_vectors(v0, v1, axis=1)
|
1842 |
+
>>> numpy.allclose(a, [1.5708, 1.5708, 1.5708, 0.95532])
|
1843 |
+
True
|
1844 |
+
|
1845 |
+
"""
|
1846 |
+
v0 = numpy.array(v0, dtype=numpy.float64, copy=False)
|
1847 |
+
v1 = numpy.array(v1, dtype=numpy.float64, copy=False)
|
1848 |
+
dot = numpy.sum(v0 * v1, axis=axis)
|
1849 |
+
dot /= vector_norm(v0, axis=axis) * vector_norm(v1, axis=axis)
|
1850 |
+
dot = numpy.clip(dot, -1.0, 1.0)
|
1851 |
+
return numpy.arccos(dot if directed else numpy.fabs(dot))
|
1852 |
+
|
1853 |
+
|
1854 |
+
def inverse_matrix(matrix):
|
1855 |
+
"""Return inverse of square transformation matrix.
|
1856 |
+
|
1857 |
+
>>> M0 = random_rotation_matrix()
|
1858 |
+
>>> M1 = inverse_matrix(M0.T)
|
1859 |
+
>>> numpy.allclose(M1, numpy.linalg.inv(M0.T))
|
1860 |
+
True
|
1861 |
+
>>> for size in range(1, 7):
|
1862 |
+
... M0 = numpy.random.rand(size, size)
|
1863 |
+
... M1 = inverse_matrix(M0)
|
1864 |
+
... if not numpy.allclose(M1, numpy.linalg.inv(M0)): print(size)
|
1865 |
+
|
1866 |
+
"""
|
1867 |
+
return numpy.linalg.inv(matrix)
|
1868 |
+
|
1869 |
+
|
1870 |
+
def concatenate_matrices(*matrices):
|
1871 |
+
"""Return concatenation of series of transformation matrices.
|
1872 |
+
|
1873 |
+
>>> M = numpy.random.rand(16).reshape((4, 4)) - 0.5
|
1874 |
+
>>> numpy.allclose(M, concatenate_matrices(M))
|
1875 |
+
True
|
1876 |
+
>>> numpy.allclose(numpy.dot(M, M.T), concatenate_matrices(M, M.T))
|
1877 |
+
True
|
1878 |
+
|
1879 |
+
"""
|
1880 |
+
M = numpy.identity(4)
|
1881 |
+
for i in matrices:
|
1882 |
+
M = numpy.dot(M, i)
|
1883 |
+
return M
|
1884 |
+
|
1885 |
+
|
1886 |
+
def is_same_transform(matrix0, matrix1):
|
1887 |
+
"""Return True if two matrices perform same transformation.
|
1888 |
+
|
1889 |
+
>>> is_same_transform(numpy.identity(4), numpy.identity(4))
|
1890 |
+
True
|
1891 |
+
>>> is_same_transform(numpy.identity(4), random_rotation_matrix())
|
1892 |
+
False
|
1893 |
+
|
1894 |
+
"""
|
1895 |
+
matrix0 = numpy.array(matrix0, dtype=numpy.float64, copy=True)
|
1896 |
+
matrix0 /= matrix0[3, 3]
|
1897 |
+
matrix1 = numpy.array(matrix1, dtype=numpy.float64, copy=True)
|
1898 |
+
matrix1 /= matrix1[3, 3]
|
1899 |
+
return numpy.allclose(matrix0, matrix1)
|
1900 |
+
|
1901 |
+
|
1902 |
+
def is_same_quaternion(q0, q1):
|
1903 |
+
"""Return True if two quaternions are equal."""
|
1904 |
+
q0 = numpy.array(q0)
|
1905 |
+
q1 = numpy.array(q1)
|
1906 |
+
return numpy.allclose(q0, q1) or numpy.allclose(q0, -q1)
|
1907 |
+
|
1908 |
+
|
1909 |
+
def _import_module(name, package=None, warn=True, postfix='_py', ignore='_'):
|
1910 |
+
"""Try import all public attributes from module into global namespace.
|
1911 |
+
|
1912 |
+
Existing attributes with name clashes are renamed with prefix.
|
1913 |
+
Attributes starting with underscore are ignored by default.
|
1914 |
+
|
1915 |
+
Return True on successful import.
|
1916 |
+
|
1917 |
+
"""
|
1918 |
+
import warnings
|
1919 |
+
from importlib import import_module
|
1920 |
+
try:
|
1921 |
+
if not package:
|
1922 |
+
module = import_module(name)
|
1923 |
+
else:
|
1924 |
+
module = import_module('.' + name, package=package)
|
1925 |
+
except ImportError as err:
|
1926 |
+
if warn:
|
1927 |
+
warnings.warn(str(err))
|
1928 |
+
else:
|
1929 |
+
for attr in dir(module):
|
1930 |
+
if ignore and attr.startswith(ignore):
|
1931 |
+
continue
|
1932 |
+
if postfix:
|
1933 |
+
if attr in globals():
|
1934 |
+
globals()[attr + postfix] = globals()[attr]
|
1935 |
+
elif warn:
|
1936 |
+
warnings.warn('no Python implementation of ' + attr)
|
1937 |
+
globals()[attr] = getattr(module, attr)
|
1938 |
+
return True
|
1939 |
+
|
1940 |
+
|
1941 |
+
_import_module('_transformations', __package__, warn=False)
|
1942 |
+
|
1943 |
+
|
1944 |
+
if __name__ == '__main__':
|
1945 |
+
import doctest
|
1946 |
+
import random # noqa: used in doctests
|
1947 |
+
try:
|
1948 |
+
numpy.set_printoptions(suppress=True, precision=5, legacy='1.13')
|
1949 |
+
except TypeError:
|
1950 |
+
numpy.set_printoptions(suppress=True, precision=5)
|
1951 |
+
doctest.testmod()
|
third_party/COTR/COTR/utils/constants.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
DEFAULT_PRECISION = 'float32'
|
2 |
+
MAX_SIZE = 256
|
3 |
+
VALID_NN_OVERLAPPING_THRESH = 0.1
|
third_party/COTR/COTR/utils/debug_utils.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def embed_breakpoint(debug_info='', terminate=True):
|
2 |
+
print('\nyou are inside a break point')
|
3 |
+
if debug_info:
|
4 |
+
print('debug info: {0}'.format(debug_info))
|
5 |
+
print('')
|
6 |
+
embedding = ('import IPython\n'
|
7 |
+
'import matplotlib.pyplot as plt\n'
|
8 |
+
'IPython.embed()\n'
|
9 |
+
)
|
10 |
+
if terminate:
|
11 |
+
embedding += (
|
12 |
+
'assert 0, \'force termination\'\n'
|
13 |
+
)
|
14 |
+
|
15 |
+
return embedding
|
third_party/COTR/COTR/utils/utils.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import smtplib
|
3 |
+
import ssl
|
4 |
+
from collections import namedtuple
|
5 |
+
|
6 |
+
from COTR.utils import debug_utils
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import cv2
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import PIL
|
13 |
+
|
14 |
+
|
15 |
+
'''
|
16 |
+
ImagePatch: patch: patch content, np array or None
|
17 |
+
x: left bound in original resolution
|
18 |
+
y: upper bound in original resolution
|
19 |
+
w: width of patch
|
20 |
+
h: height of patch
|
21 |
+
ow: width of original resolution
|
22 |
+
oh: height of original resolution
|
23 |
+
'''
|
24 |
+
ImagePatch = namedtuple('ImagePatch', ['patch', 'x', 'y', 'w', 'h', 'ow', 'oh'])
|
25 |
+
Point3D = namedtuple("Point3D", ["id", "arr_idx", "image_ids"])
|
26 |
+
Point2D = namedtuple("Point2D", ["id_3d", "xy"])
|
27 |
+
|
28 |
+
|
29 |
+
class CropCamConfig():
|
30 |
+
def __init__(self, x, y, w, h, out_w, out_h, orig_w, orig_h):
|
31 |
+
'''
|
32 |
+
xy: left upper corner
|
33 |
+
'''
|
34 |
+
# assert x > 0 and x < orig_w
|
35 |
+
# assert y > 0 and y < orig_h
|
36 |
+
# assert w < orig_w and h < orig_h
|
37 |
+
# assert x - w / 2 > 0 and x + w / 2 < orig_w
|
38 |
+
# assert y - h / 2 > 0 and y + h / 2 < orig_h
|
39 |
+
# assert h / w == out_h / out_w
|
40 |
+
self.x = x
|
41 |
+
self.y = y
|
42 |
+
self.w = w
|
43 |
+
self.h = h
|
44 |
+
self.out_w = out_w
|
45 |
+
self.out_h = out_h
|
46 |
+
self.orig_w = orig_w
|
47 |
+
self.orig_h = orig_h
|
48 |
+
|
49 |
+
def __str__(self):
|
50 |
+
out = f'original image size(h,w): [{self.orig_h}, {self.orig_w}]\n'
|
51 |
+
out += f'crop at(x,y): [{self.x}, {self.y}]\n'
|
52 |
+
out += f'crop size(h,w): [{self.h}, {self.w}]\n'
|
53 |
+
out += f'resize crop to(h,w): [{self.out_h}, {self.out_w}]'
|
54 |
+
return out
|
55 |
+
|
56 |
+
|
57 |
+
def fix_randomness(seed=42):
|
58 |
+
random.seed(seed)
|
59 |
+
torch.backends.cudnn.deterministic = True
|
60 |
+
torch.backends.cudnn.benchmark = False
|
61 |
+
torch.manual_seed(seed)
|
62 |
+
np.random.seed(seed)
|
63 |
+
|
64 |
+
|
65 |
+
def worker_init_fn(worker_id):
|
66 |
+
np.random.seed(np.random.get_state()[1][0] + worker_id)
|
67 |
+
|
68 |
+
|
69 |
+
def float_image_resize(img, shape, interp=PIL.Image.BILINEAR):
|
70 |
+
missing_channel = False
|
71 |
+
if len(img.shape) == 2:
|
72 |
+
missing_channel = True
|
73 |
+
img = img[..., None]
|
74 |
+
layers = []
|
75 |
+
img = img.transpose(2, 0, 1)
|
76 |
+
for l in img:
|
77 |
+
l = np.array(PIL.Image.fromarray(l).resize(shape[::-1], resample=interp))
|
78 |
+
assert l.shape[:2] == shape
|
79 |
+
layers.append(l)
|
80 |
+
if missing_channel:
|
81 |
+
return np.stack(layers, axis=-1)[..., 0]
|
82 |
+
else:
|
83 |
+
return np.stack(layers, axis=-1)
|
84 |
+
|
85 |
+
|
86 |
+
def is_nan(x):
|
87 |
+
"""
|
88 |
+
get mask of nan values.
|
89 |
+
:param x: torch or numpy var.
|
90 |
+
:return: a N-D array of bool. True -> nan, False -> ok.
|
91 |
+
"""
|
92 |
+
return x != x
|
93 |
+
|
94 |
+
|
95 |
+
def has_nan(x) -> bool:
|
96 |
+
"""
|
97 |
+
check whether x contains nan.
|
98 |
+
:param x: torch or numpy var.
|
99 |
+
:return: single bool, True -> x containing nan, False -> ok.
|
100 |
+
"""
|
101 |
+
if x is None:
|
102 |
+
return False
|
103 |
+
return is_nan(x).any()
|
104 |
+
|
105 |
+
|
106 |
+
def confirm(question='OK to continue?'):
|
107 |
+
"""
|
108 |
+
Ask user to enter Y or N (case-insensitive).
|
109 |
+
:return: True if the answer is Y.
|
110 |
+
:rtype: bool
|
111 |
+
"""
|
112 |
+
answer = ""
|
113 |
+
while answer not in ["y", "n"]:
|
114 |
+
answer = input(question + ' [y/n] ').lower()
|
115 |
+
return answer == "y"
|
116 |
+
|
117 |
+
|
118 |
+
def print_notification(content_list, notification_type='NOTIFICATION'):
|
119 |
+
print('---------------------- {0} ----------------------'.format(notification_type))
|
120 |
+
print()
|
121 |
+
for content in content_list:
|
122 |
+
print(content)
|
123 |
+
print()
|
124 |
+
print('----------------------------------------------------')
|
125 |
+
|
126 |
+
|
127 |
+
def torch_img_to_np_img(torch_img):
|
128 |
+
'''convert a torch image to matplotlib-able numpy image
|
129 |
+
torch use Channels x Height x Width
|
130 |
+
numpy use Height x Width x Channels
|
131 |
+
Arguments:
|
132 |
+
torch_img {[type]} -- [description]
|
133 |
+
'''
|
134 |
+
assert isinstance(torch_img, torch.Tensor), 'cannot process data type: {0}'.format(type(torch_img))
|
135 |
+
if len(torch_img.shape) == 4 and (torch_img.shape[1] == 3 or torch_img.shape[1] == 1):
|
136 |
+
return np.transpose(torch_img.detach().cpu().numpy(), (0, 2, 3, 1))
|
137 |
+
if len(torch_img.shape) == 3 and (torch_img.shape[0] == 3 or torch_img.shape[0] == 1):
|
138 |
+
return np.transpose(torch_img.detach().cpu().numpy(), (1, 2, 0))
|
139 |
+
elif len(torch_img.shape) == 2:
|
140 |
+
return torch_img.detach().cpu().numpy()
|
141 |
+
else:
|
142 |
+
raise ValueError('cannot process this image')
|
143 |
+
|
144 |
+
|
145 |
+
def np_img_to_torch_img(np_img):
|
146 |
+
"""convert a numpy image to torch image
|
147 |
+
numpy use Height x Width x Channels
|
148 |
+
torch use Channels x Height x Width
|
149 |
+
|
150 |
+
Arguments:
|
151 |
+
np_img {[type]} -- [description]
|
152 |
+
"""
|
153 |
+
assert isinstance(np_img, np.ndarray), 'cannot process data type: {0}'.format(type(np_img))
|
154 |
+
if len(np_img.shape) == 4 and (np_img.shape[3] == 3 or np_img.shape[3] == 1):
|
155 |
+
return torch.from_numpy(np.transpose(np_img, (0, 3, 1, 2)))
|
156 |
+
if len(np_img.shape) == 3 and (np_img.shape[2] == 3 or np_img.shape[2] == 1):
|
157 |
+
return torch.from_numpy(np.transpose(np_img, (2, 0, 1)))
|
158 |
+
elif len(np_img.shape) == 2:
|
159 |
+
return torch.from_numpy(np_img)
|
160 |
+
else:
|
161 |
+
raise ValueError('cannot process this image with shape: {0}'.format(np_img.shape))
|
162 |
+
|
163 |
+
|
164 |
+
def safe_load_weights(model, saved_weights):
|
165 |
+
try:
|
166 |
+
model.load_state_dict(saved_weights)
|
167 |
+
except RuntimeError:
|
168 |
+
try:
|
169 |
+
weights = saved_weights
|
170 |
+
weights = {k.replace('module.', ''): v for k, v in weights.items()}
|
171 |
+
model.load_state_dict(weights)
|
172 |
+
except RuntimeError:
|
173 |
+
try:
|
174 |
+
weights = saved_weights
|
175 |
+
weights = {'module.' + k: v for k, v in weights.items()}
|
176 |
+
model.load_state_dict(weights)
|
177 |
+
except RuntimeError:
|
178 |
+
try:
|
179 |
+
pretrained_dict = saved_weights
|
180 |
+
model_dict = model.state_dict()
|
181 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if ((k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape))}
|
182 |
+
assert len(pretrained_dict) != 0
|
183 |
+
model_dict.update(pretrained_dict)
|
184 |
+
model.load_state_dict(model_dict)
|
185 |
+
non_match_keys = set(model.state_dict().keys()) - set(pretrained_dict.keys())
|
186 |
+
notification = []
|
187 |
+
notification += ['pretrained weights PARTIALLY loaded, following are missing:']
|
188 |
+
notification += [str(non_match_keys)]
|
189 |
+
print_notification(notification, 'WARNING')
|
190 |
+
except Exception as e:
|
191 |
+
print(f'pretrained weights loading failed {e}')
|
192 |
+
exit()
|
193 |
+
print('weights safely loaded')
|
194 |
+
|
195 |
+
|
196 |
+
def visualize_corrs(img1, img2, corrs, mask=None):
|
197 |
+
if mask is None:
|
198 |
+
mask = np.ones(len(corrs)).astype(bool)
|
199 |
+
|
200 |
+
scale1 = 1.0
|
201 |
+
scale2 = 1.0
|
202 |
+
if img1.shape[1] > img2.shape[1]:
|
203 |
+
scale2 = img1.shape[1] / img2.shape[1]
|
204 |
+
w = img1.shape[1]
|
205 |
+
else:
|
206 |
+
scale1 = img2.shape[1] / img1.shape[1]
|
207 |
+
w = img2.shape[1]
|
208 |
+
# Resize if too big
|
209 |
+
max_w = 400
|
210 |
+
if w > max_w:
|
211 |
+
scale1 *= max_w / w
|
212 |
+
scale2 *= max_w / w
|
213 |
+
img1 = cv2.resize(img1, (0, 0), fx=scale1, fy=scale1)
|
214 |
+
img2 = cv2.resize(img2, (0, 0), fx=scale2, fy=scale2)
|
215 |
+
|
216 |
+
x1, x2 = corrs[:, :2], corrs[:, 2:]
|
217 |
+
h1, w1 = img1.shape[:2]
|
218 |
+
h2, w2 = img2.shape[:2]
|
219 |
+
img = np.zeros((h1 + h2, max(w1, w2), 3), dtype=img1.dtype)
|
220 |
+
img[:h1, :w1] = img1
|
221 |
+
img[h1:, :w2] = img2
|
222 |
+
# Move keypoints to coordinates to image coordinates
|
223 |
+
x1 = x1 * scale1
|
224 |
+
x2 = x2 * scale2
|
225 |
+
# recompute the coordinates for the second image
|
226 |
+
x2p = x2 + np.array([[0, h1]])
|
227 |
+
fig = plt.figure(frameon=False)
|
228 |
+
fig = plt.imshow(img)
|
229 |
+
|
230 |
+
cols = [
|
231 |
+
[0.0, 0.67, 0.0],
|
232 |
+
[0.9, 0.1, 0.1],
|
233 |
+
]
|
234 |
+
lw = .5
|
235 |
+
alpha = 1
|
236 |
+
|
237 |
+
# Draw outliers
|
238 |
+
_x1 = x1[~mask]
|
239 |
+
_x2p = x2p[~mask]
|
240 |
+
xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T
|
241 |
+
ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T
|
242 |
+
plt.plot(
|
243 |
+
xs, ys,
|
244 |
+
alpha=alpha,
|
245 |
+
linestyle="-",
|
246 |
+
linewidth=lw,
|
247 |
+
aa=False,
|
248 |
+
color=cols[1],
|
249 |
+
)
|
250 |
+
|
251 |
+
|
252 |
+
# Draw Inliers
|
253 |
+
_x1 = x1[mask]
|
254 |
+
_x2p = x2p[mask]
|
255 |
+
xs = np.stack([_x1[:, 0], _x2p[:, 0]], axis=1).T
|
256 |
+
ys = np.stack([_x1[:, 1], _x2p[:, 1]], axis=1).T
|
257 |
+
plt.plot(
|
258 |
+
xs, ys,
|
259 |
+
alpha=alpha,
|
260 |
+
linestyle="-",
|
261 |
+
linewidth=lw,
|
262 |
+
aa=False,
|
263 |
+
color=cols[0],
|
264 |
+
)
|
265 |
+
plt.scatter(xs, ys)
|
266 |
+
|
267 |
+
fig.axes.get_xaxis().set_visible(False)
|
268 |
+
fig.axes.get_yaxis().set_visible(False)
|
269 |
+
ax = plt.gca()
|
270 |
+
ax.set_axis_off()
|
271 |
+
plt.show()
|
third_party/COTR/LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
third_party/COTR/demo_face.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
COTR demo for human face
|
3 |
+
We use an off-the-shelf face landmarks detector: https://github.com/1adrianb/face-alignment
|
4 |
+
'''
|
5 |
+
import argparse
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
|
9 |
+
import cv2
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import imageio
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
from COTR.utils import utils, debug_utils
|
16 |
+
from COTR.models import build_model
|
17 |
+
from COTR.options.options import *
|
18 |
+
from COTR.options.options_utils import *
|
19 |
+
from COTR.inference.inference_helper import triangulate_corr
|
20 |
+
from COTR.inference.sparse_engine import SparseEngine
|
21 |
+
|
22 |
+
utils.fix_randomness(0)
|
23 |
+
torch.set_grad_enabled(False)
|
24 |
+
|
25 |
+
|
26 |
+
def main(opt):
|
27 |
+
model = build_model(opt)
|
28 |
+
model = model.cuda()
|
29 |
+
weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
|
30 |
+
utils.safe_load_weights(model, weights)
|
31 |
+
model = model.eval()
|
32 |
+
|
33 |
+
img_a = imageio.imread('./sample_data/imgs/face_1.png', pilmode='RGB')
|
34 |
+
img_b = imageio.imread('./sample_data/imgs/face_2.png', pilmode='RGB')
|
35 |
+
queries = np.load('./sample_data/face_landmarks.npy')[0]
|
36 |
+
|
37 |
+
engine = SparseEngine(model, 32, mode='stretching')
|
38 |
+
corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, queries_a=queries, force=False)
|
39 |
+
|
40 |
+
f, axarr = plt.subplots(1, 2)
|
41 |
+
axarr[0].imshow(img_a)
|
42 |
+
axarr[0].scatter(*queries.T, s=1)
|
43 |
+
axarr[0].title.set_text('Reference Face')
|
44 |
+
axarr[0].axis('off')
|
45 |
+
axarr[1].imshow(img_b)
|
46 |
+
axarr[1].scatter(*corrs[:, 2:].T, s=1)
|
47 |
+
axarr[1].title.set_text('Target Face')
|
48 |
+
axarr[1].axis('off')
|
49 |
+
plt.show()
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
parser = argparse.ArgumentParser()
|
54 |
+
set_COTR_arguments(parser)
|
55 |
+
parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
|
56 |
+
parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
|
57 |
+
|
58 |
+
opt = parser.parse_args()
|
59 |
+
opt.command = ' '.join(sys.argv)
|
60 |
+
|
61 |
+
layer_2_channels = {'layer1': 256,
|
62 |
+
'layer2': 512,
|
63 |
+
'layer3': 1024,
|
64 |
+
'layer4': 2048, }
|
65 |
+
opt.dim_feedforward = layer_2_channels[opt.layer]
|
66 |
+
if opt.load_weights:
|
67 |
+
opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
|
68 |
+
print_opt(opt)
|
69 |
+
main(opt)
|
third_party/COTR/demo_guided_matching.py
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Feature-free COTR guided matching for keypoints.
|
3 |
+
We use DISK(https://github.com/cvlab-epfl/disk) keypoints location.
|
4 |
+
We apply RANSAC + F matrix to further prune outliers.
|
5 |
+
Note: This script doesn't use descriptors.
|
6 |
+
'''
|
7 |
+
import argparse
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
|
11 |
+
import cv2
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
import imageio
|
15 |
+
from scipy.spatial import distance_matrix
|
16 |
+
|
17 |
+
from COTR.utils import utils, debug_utils
|
18 |
+
from COTR.models import build_model
|
19 |
+
from COTR.options.options import *
|
20 |
+
from COTR.options.options_utils import *
|
21 |
+
from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine
|
22 |
+
|
23 |
+
utils.fix_randomness(0)
|
24 |
+
torch.set_grad_enabled(False)
|
25 |
+
|
26 |
+
|
27 |
+
def main(opt):
|
28 |
+
model = build_model(opt)
|
29 |
+
model = model.cuda()
|
30 |
+
weights = torch.load(opt.load_weights_path)['model_state_dict']
|
31 |
+
utils.safe_load_weights(model, weights)
|
32 |
+
model = model.eval()
|
33 |
+
|
34 |
+
img_a = imageio.imread('./sample_data/imgs/21526113_4379776807.jpg')
|
35 |
+
img_b = imageio.imread('./sample_data/imgs/21126421_4537535153.jpg')
|
36 |
+
kp_a = np.load('./sample_data/21526113_4379776807.jpg.disk.kpts.npy')
|
37 |
+
kp_b = np.load('./sample_data/21126421_4537535153.jpg.disk.kpts.npy')
|
38 |
+
|
39 |
+
if opt.faster_infer:
|
40 |
+
engine = FasterSparseEngine(model, 32, mode='tile')
|
41 |
+
else:
|
42 |
+
engine = SparseEngine(model, 32, mode='tile')
|
43 |
+
t0 = time.time()
|
44 |
+
corrs_a_b = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True)
|
45 |
+
corrs_b_a = engine.cotr_corr_multiscale(img_b, img_a, np.linspace(0.5, 0.0625, 4), 1, max_corrs=kp_b.shape[0], queries_a=kp_b, force=True)
|
46 |
+
t1 = time.time()
|
47 |
+
print(f'COTR spent {t1-t0} seconds.')
|
48 |
+
inds_a_b = np.argmin(distance_matrix(corrs_a_b[:, 2:], kp_b), axis=1)
|
49 |
+
matched_a_b = np.stack([np.arange(kp_a.shape[0]), inds_a_b]).T
|
50 |
+
inds_b_a = np.argmin(distance_matrix(corrs_b_a[:, 2:], kp_a), axis=1)
|
51 |
+
matched_b_a = np.stack([np.arange(kp_b.shape[0]), inds_b_a]).T
|
52 |
+
|
53 |
+
good = 0
|
54 |
+
final_matches = []
|
55 |
+
for m_ab in matched_a_b:
|
56 |
+
for m_ba in matched_b_a:
|
57 |
+
if (m_ab == m_ba[::-1]).all():
|
58 |
+
good += 1
|
59 |
+
final_matches.append(m_ab)
|
60 |
+
break
|
61 |
+
final_matches = np.array(final_matches)
|
62 |
+
final_corrs = np.concatenate([kp_a[final_matches[:, 0]], kp_b[final_matches[:, 1]]], axis=1)
|
63 |
+
_, mask = cv2.findFundamentalMat(final_corrs[:, :2], final_corrs[:, 2:], cv2.FM_RANSAC, ransacReprojThreshold=5, confidence=0.999999)
|
64 |
+
utils.visualize_corrs(img_a, img_b, final_corrs[np.where(mask[:, 0])])
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
parser = argparse.ArgumentParser()
|
69 |
+
set_COTR_arguments(parser)
|
70 |
+
parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
|
71 |
+
parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
|
72 |
+
parser.add_argument('--faster_infer', type=str2bool, default=False, help='use fatser inference')
|
73 |
+
|
74 |
+
opt = parser.parse_args()
|
75 |
+
opt.command = ' '.join(sys.argv)
|
76 |
+
|
77 |
+
layer_2_channels = {'layer1': 256,
|
78 |
+
'layer2': 512,
|
79 |
+
'layer3': 1024,
|
80 |
+
'layer4': 2048, }
|
81 |
+
opt.dim_feedforward = layer_2_channels[opt.layer]
|
82 |
+
if opt.load_weights:
|
83 |
+
opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
|
84 |
+
print_opt(opt)
|
85 |
+
main(opt)
|
third_party/COTR/demo_homography.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
COTR demo for homography estimation
|
3 |
+
'''
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import imageio
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
|
14 |
+
from COTR.utils import utils, debug_utils
|
15 |
+
from COTR.models import build_model
|
16 |
+
from COTR.options.options import *
|
17 |
+
from COTR.options.options_utils import *
|
18 |
+
from COTR.inference.inference_helper import triangulate_corr
|
19 |
+
from COTR.inference.sparse_engine import SparseEngine
|
20 |
+
|
21 |
+
utils.fix_randomness(0)
|
22 |
+
torch.set_grad_enabled(False)
|
23 |
+
|
24 |
+
|
25 |
+
def main(opt):
|
26 |
+
model = build_model(opt)
|
27 |
+
model = model.cuda()
|
28 |
+
weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
|
29 |
+
utils.safe_load_weights(model, weights)
|
30 |
+
model = model.eval()
|
31 |
+
|
32 |
+
img_a = imageio.imread('./sample_data/imgs/paint_1.JPG', pilmode='RGB')
|
33 |
+
img_b = imageio.imread('./sample_data/imgs/paint_2.jpg', pilmode='RGB')
|
34 |
+
rep_img = imageio.imread('./sample_data/imgs/Meisje_met_de_parel.jpg', pilmode='RGB')
|
35 |
+
rep_mask = np.ones(rep_img.shape[:2])
|
36 |
+
lu_corner = [932, 1025]
|
37 |
+
ru_corner = [2469, 901]
|
38 |
+
lb_corner = [908, 2927]
|
39 |
+
rb_corner = [2436, 3080]
|
40 |
+
queries = np.array([lu_corner, ru_corner, lb_corner, rb_corner]).astype(np.float32)
|
41 |
+
rep_coord = np.array([[0, 0], [rep_img.shape[1], 0], [0, rep_img.shape[0]], [rep_img.shape[1], rep_img.shape[0]]]).astype(np.float32)
|
42 |
+
|
43 |
+
engine = SparseEngine(model, 32, mode='stretching')
|
44 |
+
corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, queries_a=queries, force=True)
|
45 |
+
|
46 |
+
T = cv2.getPerspectiveTransform(rep_coord, corrs[:, 2:].astype(np.float32))
|
47 |
+
vmask = cv2.warpPerspective(rep_mask, T, (img_b.shape[1], img_b.shape[0])) > 0
|
48 |
+
warped = cv2.warpPerspective(rep_img, T, (img_b.shape[1], img_b.shape[0]))
|
49 |
+
out = warped * vmask[..., None] + img_b * (~vmask[..., None])
|
50 |
+
|
51 |
+
f, axarr = plt.subplots(1, 4)
|
52 |
+
axarr[0].imshow(rep_img)
|
53 |
+
axarr[0].title.set_text('Virtual Paint')
|
54 |
+
axarr[0].axis('off')
|
55 |
+
axarr[1].imshow(img_a)
|
56 |
+
axarr[1].title.set_text('Annotated Frame')
|
57 |
+
axarr[1].axis('off')
|
58 |
+
axarr[2].imshow(img_b)
|
59 |
+
axarr[2].title.set_text('Target Frame')
|
60 |
+
axarr[2].axis('off')
|
61 |
+
axarr[3].imshow(out)
|
62 |
+
axarr[3].title.set_text('Overlay')
|
63 |
+
axarr[3].axis('off')
|
64 |
+
plt.show()
|
65 |
+
|
66 |
+
|
67 |
+
if __name__ == "__main__":
|
68 |
+
parser = argparse.ArgumentParser()
|
69 |
+
set_COTR_arguments(parser)
|
70 |
+
parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
|
71 |
+
parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
|
72 |
+
|
73 |
+
opt = parser.parse_args()
|
74 |
+
opt.command = ' '.join(sys.argv)
|
75 |
+
|
76 |
+
layer_2_channels = {'layer1': 256,
|
77 |
+
'layer2': 512,
|
78 |
+
'layer3': 1024,
|
79 |
+
'layer4': 2048, }
|
80 |
+
opt.dim_feedforward = layer_2_channels[opt.layer]
|
81 |
+
if opt.load_weights:
|
82 |
+
opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
|
83 |
+
print_opt(opt)
|
84 |
+
main(opt)
|
third_party/COTR/demo_reconstruction.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
COTR two view reconstruction with known extrinsic/intrinsic demo
|
3 |
+
'''
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import imageio
|
11 |
+
import open3d as o3d
|
12 |
+
|
13 |
+
from COTR.utils import utils, debug_utils
|
14 |
+
from COTR.models import build_model
|
15 |
+
from COTR.options.options import *
|
16 |
+
from COTR.options.options_utils import *
|
17 |
+
from COTR.inference.sparse_engine import SparseEngine, FasterSparseEngine
|
18 |
+
from COTR.projector import pcd_projector
|
19 |
+
|
20 |
+
utils.fix_randomness(0)
|
21 |
+
torch.set_grad_enabled(False)
|
22 |
+
|
23 |
+
|
24 |
+
def triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b):
|
25 |
+
A = center_a
|
26 |
+
a = dir_a / np.linalg.norm(dir_a, axis=1, keepdims=True)
|
27 |
+
B = center_b
|
28 |
+
b = dir_b / np.linalg.norm(dir_b, axis=1, keepdims=True)
|
29 |
+
c = B - A
|
30 |
+
D = A + a * ((-np.sum(a * b, axis=1) * np.sum(b * c, axis=1) + np.sum(a * c, axis=1) * np.sum(b * b, axis=1)) / (np.sum(a * a, axis=1) * np.sum(b * b, axis=1) - np.sum(a * b, axis=1) * np.sum(a * b, axis=1)))[..., None]
|
31 |
+
return D
|
32 |
+
|
33 |
+
|
34 |
+
def main(opt):
|
35 |
+
model = build_model(opt)
|
36 |
+
model = model.cuda()
|
37 |
+
weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
|
38 |
+
utils.safe_load_weights(model, weights)
|
39 |
+
model = model.eval()
|
40 |
+
|
41 |
+
img_a = imageio.imread('./sample_data/imgs/img_0.jpg', pilmode='RGB')
|
42 |
+
img_b = imageio.imread('./sample_data/imgs/img_1.jpg', pilmode='RGB')
|
43 |
+
|
44 |
+
if opt.faster_infer:
|
45 |
+
engine = FasterSparseEngine(model, 32, mode='tile')
|
46 |
+
else:
|
47 |
+
engine = SparseEngine(model, 32, mode='tile')
|
48 |
+
t0 = time.time()
|
49 |
+
corrs = engine.cotr_corr_multiscale_with_cycle_consistency(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=opt.max_corrs, queries_a=None)
|
50 |
+
t1 = time.time()
|
51 |
+
print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.')
|
52 |
+
|
53 |
+
camera_a = np.load('./sample_data/camera_0.npy', allow_pickle=True).item()
|
54 |
+
camera_b = np.load('./sample_data/camera_1.npy', allow_pickle=True).item()
|
55 |
+
center_a = camera_a['cam_center']
|
56 |
+
center_b = camera_b['cam_center']
|
57 |
+
rays_a = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(corrs[:, :2], np.ones([corrs.shape[0], 1]) * 2, camera_a['intrinsic'], motion=camera_a['c2w'])
|
58 |
+
rays_b = pcd_projector.PointCloudProjector.pcd_2d_to_pcd_3d_np(corrs[:, 2:], np.ones([corrs.shape[0], 1]) * 2, camera_b['intrinsic'], motion=camera_b['c2w'])
|
59 |
+
dir_a = rays_a - center_a
|
60 |
+
dir_b = rays_b - center_b
|
61 |
+
center_a = np.array([center_a] * corrs.shape[0])
|
62 |
+
center_b = np.array([center_b] * corrs.shape[0])
|
63 |
+
points = triangulate_rays_to_pcd(center_a, dir_a, center_b, dir_b)
|
64 |
+
colors = (img_a[tuple(np.floor(corrs[:, :2]).astype(int)[:, ::-1].T)] / 255 + img_b[tuple(np.floor(corrs[:, 2:]).astype(int)[:, ::-1].T)] / 255) / 2
|
65 |
+
colors = np.array(colors)
|
66 |
+
|
67 |
+
pcd = o3d.geometry.PointCloud()
|
68 |
+
pcd.points = o3d.utility.Vector3dVector(points)
|
69 |
+
pcd.colors = o3d.utility.Vector3dVector(colors)
|
70 |
+
o3d.visualization.draw_geometries([pcd])
|
71 |
+
|
72 |
+
|
73 |
+
if __name__ == "__main__":
|
74 |
+
parser = argparse.ArgumentParser()
|
75 |
+
set_COTR_arguments(parser)
|
76 |
+
parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
|
77 |
+
parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
|
78 |
+
parser.add_argument('--max_corrs', type=int, default=2048, help='number of correspondences')
|
79 |
+
parser.add_argument('--faster_infer', type=str2bool, default=False, help='use fatser inference')
|
80 |
+
|
81 |
+
opt = parser.parse_args()
|
82 |
+
opt.command = ' '.join(sys.argv)
|
83 |
+
|
84 |
+
layer_2_channels = {'layer1': 256,
|
85 |
+
'layer2': 512,
|
86 |
+
'layer3': 1024,
|
87 |
+
'layer4': 2048, }
|
88 |
+
opt.dim_feedforward = layer_2_channels[opt.layer]
|
89 |
+
if opt.load_weights:
|
90 |
+
opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
|
91 |
+
print_opt(opt)
|
92 |
+
main(opt)
|
third_party/COTR/demo_single_pair.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
COTR demo for a single image pair
|
3 |
+
'''
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import imageio
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
|
14 |
+
from COTR.utils import utils, debug_utils
|
15 |
+
from COTR.models import build_model
|
16 |
+
from COTR.options.options import *
|
17 |
+
from COTR.options.options_utils import *
|
18 |
+
from COTR.inference.inference_helper import triangulate_corr
|
19 |
+
from COTR.inference.sparse_engine import SparseEngine
|
20 |
+
|
21 |
+
utils.fix_randomness(0)
|
22 |
+
torch.set_grad_enabled(False)
|
23 |
+
|
24 |
+
|
25 |
+
def main(opt):
|
26 |
+
model = build_model(opt)
|
27 |
+
model = model.cuda()
|
28 |
+
weights = torch.load(opt.load_weights_path, map_location='cpu')['model_state_dict']
|
29 |
+
utils.safe_load_weights(model, weights)
|
30 |
+
model = model.eval()
|
31 |
+
|
32 |
+
img_a = imageio.imread('./sample_data/imgs/cathedral_1.jpg', pilmode='RGB')
|
33 |
+
img_b = imageio.imread('./sample_data/imgs/cathedral_2.jpg', pilmode='RGB')
|
34 |
+
|
35 |
+
engine = SparseEngine(model, 32, mode='tile')
|
36 |
+
t0 = time.time()
|
37 |
+
corrs = engine.cotr_corr_multiscale_with_cycle_consistency(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1, max_corrs=opt.max_corrs, queries_a=None)
|
38 |
+
t1 = time.time()
|
39 |
+
|
40 |
+
utils.visualize_corrs(img_a, img_b, corrs)
|
41 |
+
print(f'spent {t1-t0} seconds for {opt.max_corrs} correspondences.')
|
42 |
+
dense = triangulate_corr(corrs, img_a.shape, img_b.shape)
|
43 |
+
warped = cv2.remap(img_b, dense[..., 0].astype(np.float32), dense[..., 1].astype(np.float32), interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT)
|
44 |
+
plt.imshow(warped / 255 * 0.5 + img_a / 255 * 0.5)
|
45 |
+
plt.show()
|
46 |
+
|
47 |
+
|
48 |
+
if __name__ == "__main__":
|
49 |
+
parser = argparse.ArgumentParser()
|
50 |
+
set_COTR_arguments(parser)
|
51 |
+
parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
|
52 |
+
parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
|
53 |
+
parser.add_argument('--max_corrs', type=int, default=100, help='number of correspondences')
|
54 |
+
|
55 |
+
opt = parser.parse_args()
|
56 |
+
opt.command = ' '.join(sys.argv)
|
57 |
+
|
58 |
+
layer_2_channels = {'layer1': 256,
|
59 |
+
'layer2': 512,
|
60 |
+
'layer3': 1024,
|
61 |
+
'layer4': 2048, }
|
62 |
+
opt.dim_feedforward = layer_2_channels[opt.layer]
|
63 |
+
if opt.load_weights:
|
64 |
+
opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
|
65 |
+
print_opt(opt)
|
66 |
+
main(opt)
|
third_party/COTR/demo_wbs.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
Manually passing scale to COTR, skip the scale difference estimation.
|
3 |
+
'''
|
4 |
+
import argparse
|
5 |
+
import os
|
6 |
+
import time
|
7 |
+
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import imageio
|
12 |
+
from scipy.spatial import distance_matrix
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
|
15 |
+
from COTR.utils import utils, debug_utils
|
16 |
+
from COTR.models import build_model
|
17 |
+
from COTR.options.options import *
|
18 |
+
from COTR.options.options_utils import *
|
19 |
+
from COTR.inference.sparse_engine import SparseEngine
|
20 |
+
|
21 |
+
utils.fix_randomness(0)
|
22 |
+
torch.set_grad_enabled(False)
|
23 |
+
|
24 |
+
|
25 |
+
def main(opt):
|
26 |
+
model = build_model(opt)
|
27 |
+
model = model.cuda()
|
28 |
+
weights = torch.load(opt.load_weights_path)['model_state_dict']
|
29 |
+
utils.safe_load_weights(model, weights)
|
30 |
+
model = model.eval()
|
31 |
+
|
32 |
+
img_a = imageio.imread('./sample_data/imgs/petrzin_01.png')
|
33 |
+
img_b = imageio.imread('./sample_data/imgs/petrzin_02.png')
|
34 |
+
img_a_area = 1.0
|
35 |
+
img_b_area = 1.0
|
36 |
+
gt_corrs = np.loadtxt('./sample_data/petrzin_pts.txt')
|
37 |
+
kp_a = gt_corrs[:, :2]
|
38 |
+
kp_b = gt_corrs[:, 2:]
|
39 |
+
|
40 |
+
engine = SparseEngine(model, 32, mode='tile')
|
41 |
+
t0 = time.time()
|
42 |
+
corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.75, 0.1, 4), 1, max_corrs=kp_a.shape[0], queries_a=kp_a, force=True, areas=[img_a_area, img_b_area])
|
43 |
+
t1 = time.time()
|
44 |
+
print(f'COTR spent {t1-t0} seconds.')
|
45 |
+
|
46 |
+
utils.visualize_corrs(img_a, img_b, corrs)
|
47 |
+
plt.imshow(img_b)
|
48 |
+
plt.scatter(kp_b[:,0], kp_b[:,1])
|
49 |
+
plt.scatter(corrs[:,2], corrs[:,3])
|
50 |
+
plt.plot(np.stack([kp_b[:,0], corrs[:,2]], axis=1).T, np.stack([kp_b[:,1], corrs[:,3]], axis=1).T, color=[1,0,0])
|
51 |
+
plt.show()
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
parser = argparse.ArgumentParser()
|
56 |
+
set_COTR_arguments(parser)
|
57 |
+
parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
|
58 |
+
parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
|
59 |
+
|
60 |
+
opt = parser.parse_args()
|
61 |
+
opt.command = ' '.join(sys.argv)
|
62 |
+
|
63 |
+
layer_2_channels = {'layer1': 256,
|
64 |
+
'layer2': 512,
|
65 |
+
'layer3': 1024,
|
66 |
+
'layer4': 2048, }
|
67 |
+
opt.dim_feedforward = layer_2_channels[opt.layer]
|
68 |
+
if opt.load_weights:
|
69 |
+
opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
|
70 |
+
print_opt(opt)
|
71 |
+
main(opt)
|
third_party/COTR/environment.yml
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: cotr_env
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=main
|
7 |
+
- backcall=0.2.0=pyhd3eb1b0_0
|
8 |
+
- blas=1.0=mkl
|
9 |
+
- bzip2=1.0.8=h7b6447c_0
|
10 |
+
- ca-certificates=2021.4.13=h06a4308_1
|
11 |
+
- cairo=1.16.0=hf32fb01_1
|
12 |
+
- certifi=2020.12.5=py37h06a4308_0
|
13 |
+
- cudatoolkit=10.2.89=hfd86e86_1
|
14 |
+
- cycler=0.10.0=py37_0
|
15 |
+
- dbus=1.13.18=hb2f20db_0
|
16 |
+
- decorator=5.0.6=pyhd3eb1b0_0
|
17 |
+
- expat=2.3.0=h2531618_2
|
18 |
+
- ffmpeg=4.0=hcdf2ecd_0
|
19 |
+
- fontconfig=2.13.1=h6c09931_0
|
20 |
+
- freeglut=3.0.0=hf484d3e_5
|
21 |
+
- freetype=2.10.4=h5ab3b9f_0
|
22 |
+
- glib=2.68.1=h36276a3_0
|
23 |
+
- graphite2=1.3.14=h23475e2_0
|
24 |
+
- gst-plugins-base=1.14.0=h8213a91_2
|
25 |
+
- gstreamer=1.14.0=h28cd5cc_2
|
26 |
+
- harfbuzz=1.8.8=hffaf4a1_0
|
27 |
+
- hdf5=1.10.2=hba1933b_1
|
28 |
+
- icu=58.2=he6710b0_3
|
29 |
+
- imageio=2.9.0=pyhd3eb1b0_0
|
30 |
+
- intel-openmp=2021.2.0=h06a4308_610
|
31 |
+
- ipython=7.22.0=py37hb070fc8_0
|
32 |
+
- ipython_genutils=0.2.0=pyhd3eb1b0_1
|
33 |
+
- jasper=2.0.14=h07fcdf6_1
|
34 |
+
- jedi=0.17.0=py37_0
|
35 |
+
- jpeg=9b=h024ee3a_2
|
36 |
+
- kiwisolver=1.3.1=py37h2531618_0
|
37 |
+
- lcms2=2.12=h3be6417_0
|
38 |
+
- ld_impl_linux-64=2.33.1=h53a641e_7
|
39 |
+
- libffi=3.3=he6710b0_2
|
40 |
+
- libgcc-ng=9.1.0=hdf63c60_0
|
41 |
+
- libgfortran-ng=7.3.0=hdf63c60_0
|
42 |
+
- libglu=9.0.0=hf484d3e_1
|
43 |
+
- libopencv=3.4.2=hb342d67_1
|
44 |
+
- libopus=1.3.1=h7b6447c_0
|
45 |
+
- libpng=1.6.37=hbc83047_0
|
46 |
+
- libstdcxx-ng=9.1.0=hdf63c60_0
|
47 |
+
- libtiff=4.1.0=h2733197_1
|
48 |
+
- libuuid=1.0.3=h1bed415_2
|
49 |
+
- libuv=1.40.0=h7b6447c_0
|
50 |
+
- libvpx=1.7.0=h439df22_0
|
51 |
+
- libxcb=1.14=h7b6447c_0
|
52 |
+
- libxml2=2.9.10=hb55368b_3
|
53 |
+
- lz4-c=1.9.3=h2531618_0
|
54 |
+
- matplotlib=3.3.4=py37h06a4308_0
|
55 |
+
- matplotlib-base=3.3.4=py37h62a2d02_0
|
56 |
+
- mkl=2020.2=256
|
57 |
+
- mkl-service=2.3.0=py37he8ac12f_0
|
58 |
+
- mkl_fft=1.3.0=py37h54f3939_0
|
59 |
+
- mkl_random=1.1.1=py37h0573a6f_0
|
60 |
+
- ncurses=6.2=he6710b0_1
|
61 |
+
- ninja=1.10.2=hff7bd54_1
|
62 |
+
- numpy=1.19.2=py37h54aff64_0
|
63 |
+
- numpy-base=1.19.2=py37hfa32c7d_0
|
64 |
+
- olefile=0.46=py37_0
|
65 |
+
- opencv=3.4.2=py37h6fd60c2_1
|
66 |
+
- openssl=1.1.1k=h27cfd23_0
|
67 |
+
- parso=0.8.2=pyhd3eb1b0_0
|
68 |
+
- pcre=8.44=he6710b0_0
|
69 |
+
- pexpect=4.8.0=pyhd3eb1b0_3
|
70 |
+
- pickleshare=0.7.5=pyhd3eb1b0_1003
|
71 |
+
- pillow=8.2.0=py37he98fc37_0
|
72 |
+
- pip=21.0.1=py37h06a4308_0
|
73 |
+
- pixman=0.40.0=h7b6447c_0
|
74 |
+
- prompt-toolkit=3.0.17=pyh06a4308_0
|
75 |
+
- ptyprocess=0.7.0=pyhd3eb1b0_2
|
76 |
+
- py-opencv=3.4.2=py37hb342d67_1
|
77 |
+
- pygments=2.8.1=pyhd3eb1b0_0
|
78 |
+
- pyparsing=2.4.7=pyhd3eb1b0_0
|
79 |
+
- pyqt=5.9.2=py37h05f1152_2
|
80 |
+
- python=3.7.10=hdb3f193_0
|
81 |
+
- python-dateutil=2.8.1=pyhd3eb1b0_0
|
82 |
+
- pytorch=1.7.1=py3.7_cuda10.2.89_cudnn7.6.5_0
|
83 |
+
- qt=5.9.7=h5867ecd_1
|
84 |
+
- readline=8.1=h27cfd23_0
|
85 |
+
- scipy=1.2.1=py37h7c811a0_0
|
86 |
+
- setuptools=52.0.0=py37h06a4308_0
|
87 |
+
- sip=4.19.8=py37hf484d3e_0
|
88 |
+
- six=1.15.0=py37h06a4308_0
|
89 |
+
- sqlite=3.35.4=hdfb4753_0
|
90 |
+
- tk=8.6.10=hbc83047_0
|
91 |
+
- torchaudio=0.7.2=py37
|
92 |
+
- torchvision=0.8.2=py37_cu102
|
93 |
+
- tornado=6.1=py37h27cfd23_0
|
94 |
+
- tqdm=4.59.0=pyhd3eb1b0_1
|
95 |
+
- traitlets=5.0.5=pyhd3eb1b0_0
|
96 |
+
- typing_extensions=3.7.4.3=pyha847dfd_0
|
97 |
+
- vispy=0.5.3=py37hee6b756_0
|
98 |
+
- wcwidth=0.2.5=py_0
|
99 |
+
- wheel=0.36.2=pyhd3eb1b0_0
|
100 |
+
- xz=5.2.5=h7b6447c_0
|
101 |
+
- zlib=1.2.11=h7b6447c_3
|
102 |
+
- zstd=1.4.9=haebb681_0
|
103 |
+
- pip:
|
104 |
+
- tables==3.6.1
|
third_party/COTR/out/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
third_party/COTR/out/.placeholder
ADDED
File without changes
|
third_party/COTR/out/default/checkpoint.pth.tar
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:abfa1183408dc566535146b41508ed02084d5f5d1a150f5c188ee479463d6d5c
|
3 |
+
size 219363688
|