Realcat committed
Commit: 2eaeef9
Parent(s): 68a65da
update:sift and update lightglue
- common/config.yaml +13 -1
- common/utils.py +13 -4
- hloc/extract_features.py +3 -2
- hloc/extractors/alike.py +2 -0
- hloc/extractors/d2net.py +4 -0
- hloc/extractors/darkfeat.py +2 -1
- hloc/extractors/dedode.py +2 -3
- hloc/extractors/example.py +1 -0
- hloc/extractors/lanet.py +2 -0
- hloc/extractors/r2d2.py +2 -0
- hloc/extractors/rekd.py +2 -0
- hloc/extractors/rord.py +4 -5
- hloc/extractors/sift.py +224 -0
- hloc/extractors/superpoint.py +2 -0
- hloc/match_dense.py +2 -5
- hloc/match_features.py +38 -14
- hloc/matchers/duster.py +36 -30
- hloc/matchers/lightglue.py +10 -0
- hloc/matchers/sgmnet.py +6 -2
- hloc/matchers/sold2.py +1 -0
- hloc/utils/viz.py +1 -0
- third_party/LightGlue/.flake8 +4 -0
- third_party/LightGlue/.github/workflows/code-quality.yml +24 -0
- third_party/LightGlue/.gitignore +162 -6
- third_party/LightGlue/LICENSE +1 -1
- third_party/LightGlue/README.md +71 -25
- third_party/LightGlue/assets/DSC_0410.JPG +0 -0
- third_party/LightGlue/assets/DSC_0411.JPG +0 -0
- third_party/LightGlue/assets/benchmark.png +3 -0
- third_party/LightGlue/assets/benchmark_cpu.png +3 -0
- third_party/LightGlue/benchmark.py +255 -0
- third_party/LightGlue/demo.ipynb +29 -22
- third_party/LightGlue/lightglue/__init__.py +7 -4
- third_party/LightGlue/lightglue/aliked.py +758 -0
- third_party/LightGlue/lightglue/disk.py +10 -24
- third_party/LightGlue/lightglue/dog_hardnet.py +41 -0
- third_party/LightGlue/lightglue/lightglue.py +331 -146
- third_party/LightGlue/lightglue/sift.py +216 -0
- third_party/LightGlue/lightglue/superpoint.py +21 -36
- third_party/LightGlue/lightglue/utils.py +25 -10
- third_party/LightGlue/lightglue/viz2d.py +1 -1
- third_party/LightGlue/pyproject.toml +30 -0
- third_party/LightGlue/setup.py +0 -27
common/config.yaml
CHANGED
@@ -25,7 +25,7 @@ matcher_zoo:
       source: "CVPR 2024"
       github: https://github.com/Vincentqyw/omniglue-onnx
       paper: https://arxiv.org/abs/2405.12979
-      project: https://hwjiang1510.github.io/OmniGlue
+      project: https://hwjiang1510.github.io/OmniGlue
       display: true
   DUSt3R:
     # TODO: duster is under development
@@ -40,6 +40,7 @@ matcher_zoo:
       project: https://dust3r.europe.naverlabs.com
       display: true
   GIM(dkm):
+    enable: false
    matcher: gim(dkm)
    dense: true
    info:
@@ -197,6 +198,17 @@ matcher_zoo:
       paper: https://arxiv.org/abs/1712.07629
       project: null
       display: false
+  sift+lightglue:
+    matcher: sift-lightglue
+    feature: sift
+    dense: false
+    info:
+      name: LightGlue #dispaly name
+      source: "ICCV 2023"
+      github: https://github.com/cvg/LightGlue
+      paper: https://arxiv.org/pdf/2306.13643
+      project: null
+      display: true
  disk+lightglue:
    matcher: disk-lightglue
    feature: disk
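The new `sift+lightglue` entry mirrors the existing `disk+lightglue` block. As a rough, non-authoritative sketch of how such an entry could be consumed (only the key names come from the YAML above; the loading code, and the assumption that `matcher_zoo` is a top-level key, are not part of this commit):

```python
# Illustrative only: resolving a matcher_zoo entry such as "sift+lightglue".
import yaml

with open("common/config.yaml") as f:
    config = yaml.safe_load(f)

entry = config["matcher_zoo"]["sift+lightglue"]
print(entry["matcher"])  # "sift-lightglue" -> key into hloc.match_features.confs
print(entry["feature"])  # "sift"           -> key into hloc.extract_features.confs
print(entry["dense"])    # False            -> sparse pipeline: extract first, then match
```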
common/utils.py
CHANGED
@@ -7,6 +7,7 @@ import psutil
 import shutil
 import numpy as np
 import gradio as gr
+from PIL import Image
 from pathlib import Path
 import poselib
 from itertools import combinations
@@ -231,10 +232,10 @@ def gen_examples():
         return [pairs[i] for i in selected]

     # rotated examples
-    def gen_rot_image_pairs(count: int =
+    def gen_rot_image_pairs(count: int = 10):
         path = ROOT / "datasets/sacre_coeur/mapping"
         path_rot = ROOT / "datasets/sacre_coeur/mapping_rot"
-        rot_list = [45,
+        rot_list = [45, 180, 90, 225, 270]
         pairs = []
         for file in os.listdir(path):
             if file.lower().endswith((".jpg", ".jpeg", ".png")):
@@ -274,6 +275,7 @@ def gen_examples():
     # image pair path
     pairs = gen_images_pairs()
     pairs += gen_rot_image_pairs()
+    pairs += gen_scale_image_pairs()
     pairs += gen_image_pairs_wxbs()

     match_setting_threshold = DEFAULT_SETTING_THRESHOLD
@@ -1015,8 +1017,15 @@ ransac_zoo = {


 def rotate_image(input_path, degrees, output_path):
-    from PIL import Image
-
     img = Image.open(input_path)
     img_rotated = img.rotate(-degrees)
     img_rotated.save(output_path)
+
+
+def scale_image(input_path, scale_factor, output_path):
+    img = Image.open(input_path)
+    width, height = img.size
+    new_width = int(width * scale_factor)
+    new_height = int(height * scale_factor)
+    img_resized = img.resize((new_width, new_height))
+    img_resized.save(output_path)
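A minimal usage sketch for the two helpers touched above; the input path and output locations are placeholders, not files guaranteed to exist in the repo, and the `common.utils` import path is an assumption about how the module is packaged:

```python
# Minimal sketch of the helpers exercised by this diff.
from common.utils import rotate_image, scale_image

src = "datasets/sacre_coeur/mapping/some_image.jpg"  # placeholder path
rotate_image(src, 45, "/tmp/rotated_45.jpg")  # PIL rotate(-degrees): clockwise by 45 deg
scale_image(src, 0.5, "/tmp/half_size.jpg")   # resize to 50% of width and height
```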
hloc/extract_features.py
CHANGED
@@ -131,6 +131,7 @@ confs = {
         "output": "feats-rootsift-n5000-r1600",
         "model": {
             "name": "dog",
+            "descriptor": "rootsift",
             "max_keypoints": 5000,
         },
         "preprocessing": {
@@ -145,8 +146,8 @@ confs = {
     "sift": {
         "output": "feats-sift-n5000-r1600",
         "model": {
-            "name": "
-            "
+            "name": "sift",
+            "rootsift": True,
             "max_keypoints": 5000,
         },
         "preprocessing": {
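Sketch, not part of the commit: running extraction with the updated "sift" configuration. The signature of `extract_features.main` follows upstream hloc (`conf, image_dir, export_dir`) and is assumed to be unchanged in this fork; the paths are placeholders.

```python
# Hedged sketch of extracting the new SIFT features into an HDF5 file.
from pathlib import Path
from hloc import extract_features

conf = extract_features.confs["sift"]  # now {"name": "sift", "rootsift": True, ...}
feature_path = extract_features.main(
    conf,
    Path("datasets/sacre_coeur/mapping"),  # placeholder image directory
    Path("outputs/demo"),                  # placeholder export directory
)
```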
hloc/extractors/alike.py
CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
 import torch

 from ..utils.base_model import BaseModel
+from hloc import logger

 alike_path = Path(__file__).parent / "../../third_party/ALIKE"
 sys.path.append(str(alike_path))
@@ -33,6 +34,7 @@ class Alike(BaseModel):
             scores_th=conf["detection_threshold"],
             n_limit=conf["max_keypoints"],
         )
+        logger.info(f"Load Alike model done.")

     def _forward(self, data):
         image = data["image"]
hloc/extractors/d2net.py
CHANGED
@@ -4,13 +4,16 @@ import subprocess
 import torch

 from ..utils.base_model import BaseModel
+from hloc import logger

 d2net_path = Path(__file__).parent / "../../third_party"
 sys.path.append(str(d2net_path))
 from d2net.lib.model_test import D2Net as _D2Net
 from d2net.lib.pyramid import process_multiscale
+
 d2net_path = Path(__file__).parent / "../../third_party/d2net"

+
 class D2Net(BaseModel):
     default_conf = {
         "model_name": "d2_tf.pth",
@@ -36,6 +39,7 @@ class D2Net(BaseModel):
         self.net = _D2Net(
             model_file=model_file, use_relu=conf["use_relu"], use_cuda=False
         )
+        logger.info(f"Load D2Net model done.")

     def _forward(self, data):
         image = data["image"]
hloc/extractors/darkfeat.py
CHANGED
@@ -2,7 +2,7 @@ import sys
 from pathlib import Path
 import subprocess
 from ..utils.base_model import BaseModel
-from
+from hloc import logger

 darkfeat_path = Path(__file__).parent / "../../third_party/DarkFeat"
 sys.path.append(str(darkfeat_path))
@@ -43,6 +43,7 @@ class DarkFeat(BaseModel):
             raise e

         self.net = DarkFeat_(model_path)
+        logger.info(f"Load DarkFeat model done.")

     def _forward(self, data):
         pred = self.net({"image": data["image"]})
hloc/extractors/dedode.py
CHANGED
@@ -4,7 +4,7 @@ import subprocess
 import torch
 from PIL import Image
 from ..utils.base_model import BaseModel
-from
+from hloc import logger
 import torchvision.transforms as transforms

 dedode_path = Path(__file__).parent / "../../third_party/DeDoDe"
@@ -15,6 +15,7 @@ from DeDoDe.utils import to_pixel_coords

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
 class DeDoDe(BaseModel):
     default_conf = {
         "name": "dedode",
@@ -61,8 +62,6 @@ class DeDoDe(BaseModel):
         )
         subprocess.run(cmd, check=True)

-        logger.info(f"Loading DeDoDe model...")
-
         # load the model
         weights_detector = torch.load(model_detector_path, map_location="cpu")
         weights_descriptor = torch.load(
hloc/extractors/example.py
CHANGED
@@ -13,6 +13,7 @@ sys.path.append(str(example_path))

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
 class Example(BaseModel):
     # change to your default configs
     default_conf = {
hloc/extractors/lanet.py
CHANGED
@@ -4,6 +4,7 @@ import subprocess
 import torch

 from ..utils.base_model import BaseModel
+from hloc import logger

 lanet_path = Path(__file__).parent / "../../third_party/lanet"
 sys.path.append(str(lanet_path))
@@ -29,6 +30,7 @@ class LANet(BaseModel):
         self.net = PointModel(is_test=True)
         state_dict = torch.load(model_path, map_location="cpu")
         self.net.load_state_dict(state_dict["model_state"])
+        logger.info(f"Load LANet model done.")

     def _forward(self, data):
         image = data["image"]
hloc/extractors/r2d2.py
CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
 import torchvision.transforms as tvf

 from ..utils.base_model import BaseModel
+from hloc import logger

 base_path = Path(__file__).parent / "../../third_party"
 sys.path.append(str(base_path))
@@ -34,6 +35,7 @@ class R2D2(BaseModel):
             rel_thr=conf["reliability_threshold"],
             rep_thr=conf["repetability_threshold"],
         )
+        logger.info(f"Load R2D2 model done.")

     def _forward(self, data):
         img = data["image"]
hloc/extractors/rekd.py
CHANGED
@@ -4,6 +4,7 @@ import subprocess
 import torch

 from ..utils.base_model import BaseModel
+from hloc import logger

 rekd_path = Path(__file__).parent / "../../third_party"
 sys.path.append(str(rekd_path))
@@ -28,6 +29,7 @@ class REKD(BaseModel):
         self.net = REKD_(is_test=True)
         state_dict = torch.load(model_path, map_location="cpu")
         self.net.load_state_dict(state_dict["model_state"])
+        logger.info(f"Load REKD model done.")

     def _forward(self, data):
         image = data["image"]
hloc/extractors/rord.py
CHANGED
@@ -4,13 +4,14 @@ import subprocess
 import torch

 from ..utils.base_model import BaseModel
-from
+from hloc import logger

 rord_path = Path(__file__).parent / "../../third_party"
 sys.path.append(str(rord_path))
 from RoRD.lib.model_test import D2Net as _RoRD
 from RoRD.lib.pyramid import process_multiscale

+
 class RoRD(BaseModel):
     default_conf = {
         "model_name": "rord.pth",
@@ -32,9 +33,7 @@ class RoRD(BaseModel):
             model_path.parent.mkdir(exist_ok=True)
             cmd_wo_proxy = ["gdown", link, "-O", str(model_path)]
             cmd = ["gdown", link, "-O", str(model_path), "--proxy", self.proxy]
-            logger.info(
-                f"Downloading the RoRD model with `{cmd_wo_proxy}`."
-            )
+            logger.info(f"Downloading the RoRD model with `{cmd_wo_proxy}`.")
             try:
                 subprocess.run(cmd_wo_proxy, check=True)
             except subprocess.CalledProcessError as e:
@@ -44,10 +43,10 @@ class RoRD(BaseModel):
             except subprocess.CalledProcessError as e:
                 logger.error(f"Failed to download the RoRD model.")
                 raise e
-        logger.info("RoRD model loaded.")
         self.net = _RoRD(
             model_file=model_path, use_relu=conf["use_relu"], use_cuda=False
         )
+        logger.info(f"Load RoRD model done.")

     def _forward(self, data):
         image = data["image"]
hloc/extractors/sift.py
ADDED
@@ -0,0 +1,224 @@
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from kornia.color import rgb_to_grayscale
+from packaging import version
+from omegaconf import OmegaConf
+
+try:
+    import pycolmap
+except ImportError:
+    pycolmap = None
+from hloc import logger
+from ..utils.base_model import BaseModel
+
+
+def filter_dog_point(
+    points, scales, angles, image_shape, nms_radius, scores=None
+):
+    h, w = image_shape
+    ij = np.round(points - 0.5).astype(int).T[::-1]
+
+    # Remove duplicate points (identical coordinates).
+    # Pick highest scale or score
+    s = scales if scores is None else scores
+    buffer = np.zeros((h, w))
+    np.maximum.at(buffer, tuple(ij), s)
+    keep = np.where(buffer[tuple(ij)] == s)[0]
+
+    # Pick lowest angle (arbitrary).
+    ij = ij[:, keep]
+    buffer[:] = np.inf
+    o_abs = np.abs(angles[keep])
+    np.minimum.at(buffer, tuple(ij), o_abs)
+    mask = buffer[tuple(ij)] == o_abs
+    ij = ij[:, mask]
+    keep = keep[mask]
+
+    if nms_radius > 0:
+        # Apply NMS on the remaining points
+        buffer[:] = 0
+        buffer[tuple(ij)] = s[keep]  # scores or scale
+
+        local_max = torch.nn.functional.max_pool2d(
+            torch.from_numpy(buffer).unsqueeze(0),
+            kernel_size=nms_radius * 2 + 1,
+            stride=1,
+            padding=nms_radius,
+        ).squeeze(0)
+        is_local_max = buffer == local_max.numpy()
+        keep = keep[is_local_max[tuple(ij)]]
+    return keep
+
+
+def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
+    x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
+    x.clip_(min=eps).sqrt_()
+    return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
+
+
+def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
+    """
+    Detect keypoints using OpenCV Detector.
+    Optionally, perform description.
+    Args:
+        features: OpenCV based keypoints detector and descriptor
+        image: Grayscale image of uint8 data type
+    Returns:
+        keypoints: 1D array of detected cv2.KeyPoint
+        scores: 1D array of responses
+        descriptors: 1D array of descriptors
+    """
+    detections, descriptors = features.detectAndCompute(image, None)
+    points = np.array([k.pt for k in detections], dtype=np.float32)
+    scores = np.array([k.response for k in detections], dtype=np.float32)
+    scales = np.array([k.size for k in detections], dtype=np.float32)
+    angles = np.deg2rad(
+        np.array([k.angle for k in detections], dtype=np.float32)
+    )
+    return points, scores, scales, angles, descriptors
+
+
+class SIFT(BaseModel):
+    default_conf = {
+        "rootsift": True,
+        "nms_radius": 0,  # None to disable filtering entirely.
+        "max_keypoints": 4096,
+        "backend": "opencv",  # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
+        "detection_threshold": 0.0066667,  # from COLMAP
+        "edge_threshold": 10,
+        "first_octave": -1,  # only used by pycolmap, the default of COLMAP
+        "num_octaves": 4,
+    }
+
+    required_data_keys = ["image"]
+
+    def _init(self, conf):
+        self.conf = OmegaConf.create(self.conf)
+        backend = self.conf.backend
+        if backend.startswith("pycolmap"):
+            if pycolmap is None:
+                raise ImportError(
+                    "Cannot find module pycolmap: install it with pip"
+                    "or use backend=opencv."
+                )
+            options = {
+                "peak_threshold": self.conf.detection_threshold,
+                "edge_threshold": self.conf.edge_threshold,
+                "first_octave": self.conf.first_octave,
+                "num_octaves": self.conf.num_octaves,
+                "normalization": pycolmap.Normalization.L2,  # L1_ROOT is buggy.
+            }
+            device = (
+                "auto"
+                if backend == "pycolmap"
+                else backend.replace("pycolmap_", "")
+            )
+            if (
+                backend == "pycolmap_cpu" or not pycolmap.has_cuda
+            ) and pycolmap.__version__ < "0.5.0":
+                warnings.warn(
+                    "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
+                    "consider upgrading pycolmap or use the CUDA version.",
+                    stacklevel=1,
+                )
+            else:
+                options["max_num_features"] = self.conf.max_keypoints
+            self.sift = pycolmap.Sift(options=options, device=device)
+        elif backend == "opencv":
+            self.sift = cv2.SIFT_create(
+                contrastThreshold=self.conf.detection_threshold,
+                nfeatures=self.conf.max_keypoints,
+                edgeThreshold=self.conf.edge_threshold,
+                nOctaveLayers=self.conf.num_octaves,
+            )
+        else:
+            backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
+            raise ValueError(
+                f"Unknown backend: {backend} not in "
+                f"{{{','.join(backends)}}}."
+            )
+        logger.info(f"Load SIFT model done.")
+
+    def extract_single_image(self, image: torch.Tensor):
+        image_np = image.cpu().numpy().squeeze(0)
+
+        if self.conf.backend.startswith("pycolmap"):
+            if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
+                detections, descriptors = self.sift.extract(image_np)
+                scores = None  # Scores are not exposed by COLMAP anymore.
+            else:
+                detections, scores, descriptors = self.sift.extract(image_np)
+            keypoints = detections[:, :2]  # Keep only (x, y).
+            scales, angles = detections[:, -2:].T
+            if scores is not None and (
+                self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
+            ):
+                # Set the scores as a combination of abs. response and scale.
+                scores = np.abs(scores) * scales
+        elif self.conf.backend == "opencv":
+            # TODO: Check if opencv keypoints are already in corner convention
+            keypoints, scores, scales, angles, descriptors = run_opencv_sift(
+                self.sift, (image_np * 255.0).astype(np.uint8)
+            )
+        pred = {
+            "keypoints": keypoints,
+            "scales": scales,
+            "oris": angles,
+            "descriptors": descriptors,
+        }
+        if scores is not None:
+            pred["scores"] = scores
+
+        # sometimes pycolmap returns points outside the image. We remove them
+        if self.conf.backend.startswith("pycolmap"):
+            is_inside = (
+                pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
+            ).all(-1)
+            pred = {k: v[is_inside] for k, v in pred.items()}
+
+        if self.conf.nms_radius is not None:
+            keep = filter_dog_point(
+                pred["keypoints"],
+                pred["scales"],
+                pred["oris"],
+                image_np.shape,
+                self.conf.nms_radius,
+                scores=pred.get("scores"),
+            )
+            pred = {k: v[keep] for k, v in pred.items()}
+
+        pred = {k: torch.from_numpy(v) for k, v in pred.items()}
+        if scores is not None:
+            # Keep the k keypoints with highest score
+            num_points = self.conf.max_keypoints
+            if num_points is not None and len(pred["keypoints"]) > num_points:
+                indices = torch.topk(pred["scores"], num_points).indices
+                pred = {k: v[indices] for k, v in pred.items()}
+        return pred
+
+    def _forward(self, data: dict) -> dict:
+        image = data["image"]
+        if image.shape[1] == 3:
+            image = rgb_to_grayscale(image)
+        device = image.device
+        image = image.cpu()
+        pred = []
+        for k in range(len(image)):
+            img = image[k]
+            if "image_size" in data.keys():
+                # avoid extracting points in padded areas
+                w, h = data["image_size"][k]
+                img = img[:, :h, :w]
+            p = self.extract_single_image(img)
+            pred.append(p)
+        pred = {
+            k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]
+        }
+        if self.conf.rootsift:
+            pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
+        pred["descriptors"] = pred["descriptors"].permute(0, 2, 1)
+        pred["keypoint_scores"] = pred["scores"].clone()
+        return pred
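A rough standalone check of the new SIFT extractor. It assumes the hloc `BaseModel` takes a conf dict at construction and dispatches to `_init`/`_forward`, as the other extractors in this repo do; a real photograph is a better test input than random noise.

```python
# Hedged sketch: instantiate the new extractor directly and inspect its outputs.
import torch
from hloc.extractors.sift import SIFT

model = SIFT({"backend": "opencv", "rootsift": True, "max_keypoints": 1024}).eval()
image = torch.rand(1, 1, 240, 320)  # grayscale batch, values in [0, 1]
pred = model({"image": image})
print(pred["keypoints"].shape)    # (1, N, 2) x/y keypoints
print(pred["descriptors"].shape)  # (1, 128, N) after the final permute
print(pred["scales"].shape, pred["oris"].shape, pred["keypoint_scores"].shape)
```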
hloc/extractors/superpoint.py
CHANGED
@@ -3,6 +3,7 @@ from pathlib import Path
 import torch

 from ..utils.base_model import BaseModel
+from hloc import logger

 sys.path.append(str(Path(__file__).parent / "../../third_party"))
 from SuperGluePretrainedNetwork.models import superpoint  # noqa E402
@@ -42,6 +43,7 @@ class SuperPoint(BaseModel):
         if conf["fix_sampling"]:
             superpoint.sample_descriptors = sample_descriptors_fix_sampling
         self.net = superpoint.SuperPoint(conf)
+        logger.info(f"Load SuperPoint model done.")

     def _forward(self, data):
         return self.net(data, self.conf)
hloc/match_dense.py
CHANGED
@@ -138,11 +138,8 @@ confs = {
         },
         "preprocessing": {
             "grayscale": False,
-            "
-            "
-            "width": 512,
-            "height": 512,
-            "dfactor": 8,
+            "resize_max": 512,
+            "dfactor": 16,
         },
     },
     "xfeat_dense": {
hloc/match_features.py
CHANGED
@@ -63,7 +63,7 @@ confs = {
         },
     },
     "disk-lightglue": {
-        "output": "matches-lightglue",
+        "output": "matches-disk-lightglue",
         "model": {
             "name": "lightglue",
             "match_threshold": 0.2,
@@ -79,6 +79,24 @@ confs = {
             "force_resize": False,
         },
    },
+    "sift-lightglue": {
+        "output": "matches-sift-lightglue",
+        "model": {
+            "name": "lightglue",
+            "match_threshold": 0.2,
+            "width_confidence": 0.99,  # for point pruning
+            "depth_confidence": 0.95,  # for early stopping,
+            "features": "sift",
+            "add_scale_ori": True,
+            "model_name": "sift_lightglue.pth",
+        },
+        "preprocessing": {
+            "grayscale": True,
+            "resize_max": 1024,
+            "dfactor": 8,
+            "force_resize": False,
+        },
+    },
     "sgmnet": {
         "output": "matches-sgmnet",
         "model": {
@@ -339,19 +357,25 @@ def match_images(model, feat0, feat1):
         feat0["keypoints"] = feat0["keypoints"][0][None]
     if isinstance(feat1["keypoints"], list):
         feat1["keypoints"] = feat1["keypoints"][0][None]
-        }
-
+    input_dict = {
+        "image0": feat0["image"],
+        "keypoints0": feat0["keypoints"],
+        "scores0": feat0["scores"][0].unsqueeze(0),
+        "descriptors0": desc0,
+        "image1": feat1["image"],
+        "keypoints1": feat1["keypoints"],
+        "scores1": feat1["scores"][0].unsqueeze(0),
+        "descriptors1": desc1,
+    }
+    if "scales" in feat0:
+        input_dict = {**input_dict, "scales0": feat0["scales"]}
+    if "scales" in feat1:
+        input_dict = {**input_dict, "scales1": feat1["scales"]}
+    if "oris" in feat0:
+        input_dict = {**input_dict, "oris0": feat0["oris"]}
+    if "oris" in feat1:
+        input_dict = {**input_dict, "oris1": feat1["oris"]}
+    pred = model(input_dict)
     pred = {
         k: v.cpu().detach()[0] if isinstance(v, torch.Tensor) else v
         for k, v in pred.items()
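Sketch of wiring the new configuration into the standard sparse pipeline. The `main()` signatures follow upstream hloc conventions and are assumptions here, as are the dataset path and the pairs file:

```python
# Hedged sketch: extract SIFT features, then match them with sift-lightglue.
from pathlib import Path
from hloc import extract_features, match_features

feature_conf = extract_features.confs["sift"]
matcher_conf = match_features.confs["sift-lightglue"]

images = Path("datasets/sacre_coeur/mapping")  # placeholder dataset
outputs = Path("outputs/demo")
pairs = outputs / "pairs.txt"                  # assumed to exist already

features = extract_features.main(feature_conf, images, outputs)
matches = match_features.main(matcher_conf, pairs, feature_conf["output"], outputs)
```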
hloc/matchers/duster.py
CHANGED
@@ -13,7 +13,7 @@ duster_path = Path(__file__).parent / "../../third_party/dust3r"
 sys.path.append(str(duster_path))

 from dust3r.inference import inference
-from dust3r.model import load_model
+from dust3r.model import load_model, AsymmetricCroCo3DStereo
 from dust3r.image_pairs import make_pairs
 from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
 from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
@@ -33,7 +33,11 @@ class Duster(BaseModel):
         self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
         self.model_path = self.conf["model_path"]
         self.download_weights()
-        self.net = load_model(self.model_path, device)
+        # self.net = load_model(self.model_path, device)
+        self.net = AsymmetricCroCo3DStereo.from_pretrained(
+            self.model_path
+            # "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
+        ).to(device)
         logger.info(f"Loaded Dust3r model")

     def download_weights(self):
@@ -68,8 +72,11 @@ class Duster(BaseModel):

     def _forward(self, data):
         img0, img1 = data["image0"], data["image1"]
-
-
+        mean = torch.tensor([0.5, 0.5, 0.5]).to(device)
+        std = torch.tensor([0.5, 0.5, 0.5]).to(device)
+
+        img0 = (img0 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
+        img1 = (img1 - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)

         images = [
             {"img": img0, "idx": 0, "instance": 0},
@@ -79,22 +86,13 @@ class Duster(BaseModel):
             images, scene_graph="complete", prefilter=None, symmetrize=True
         )
         output = inference(pairs, self.net, device, batch_size=1)
-
         scene = global_aligner(
             output, device=device, mode=GlobalAlignerMode.PairViewer
         )
-        batch_size = 1
-        schedule = "cosine"
-        lr = 0.01
-        niter = 300
-        loss = scene.compute_global_alignment(
-            init="mst", niter=niter, schedule=schedule, lr=lr
-        )
-
         # retrieve useful values from scene:
+        imgs = scene.imgs
         confidence_masks = scene.get_masks()
         pts3d = scene.get_pts3d()
-        imgs = scene.imgs
         pts2d_list, pts3d_list = [], []
         for i in range(2):
             conf_i = confidence_masks[i].cpu().numpy()
@@ -102,21 +100,29 @@ class Duster(BaseModel):
                 xy_grid(*imgs[i].shape[:2][::-1])[conf_i]
             )  # imgs[i].shape[:2] = (H, W)
             pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
-        reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
-            *pts3d_list
-        )
-        logger.info(f"Found {num_matches} matches")
-        mkpts1 = pts2d_list[1][reciprocal_in_P2]
-        mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
-
-        top_k = self.conf["max_keypoints"]
-        if top_k is not None and len(mkpts0) > top_k:
-            keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
-            mkpts0 = mkpts0[keep]
-            mkpts1 = mkpts1[keep]
-        pred = {
-            "keypoints0": torch.from_numpy(mkpts0),
-            "keypoints1": torch.from_numpy(mkpts1),
-        }

+        if len(pts3d_list[1]) == 0:
+            pred = {
+                "keypoints0": torch.zeros([0, 2]),
+                "keypoints1": torch.zeros([0, 2]),
+            }
+            logger.warning(f"Matched {0} points")
+        else:
+            reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
+                *pts3d_list
+            )
+            logger.info(f"Found {num_matches} matches")
+            mkpts1 = pts2d_list[1][reciprocal_in_P2]
+            mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
+            top_k = self.conf["max_keypoints"]
+            if top_k is not None and len(mkpts0) > top_k:
+                keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(
+                    int
+                )
+                mkpts0 = mkpts0[keep]
+                mkpts1 = mkpts1[keep]
+            pred = {
+                "keypoints0": torch.from_numpy(mkpts0),
+                "keypoints1": torch.from_numpy(mkpts1),
+            }
         return pred
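A quick sanity check, not from the commit, that the explicit mean/std arithmetic added to `_forward` matches the torchvision `Normalize` transform the class already builds in `_init`:

```python
import torch
import torchvision.transforms as tfm

img = torch.rand(1, 3, 16, 16)
mean = torch.tensor([0.5, 0.5, 0.5])
std = torch.tensor([0.5, 0.5, 0.5])
manual = (img - mean.view(1, 3, 1, 1)) / std.view(1, 3, 1, 1)
module = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(img)
print(torch.allclose(manual, module))  # True
```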
hloc/matchers/lightglue.py
CHANGED
@@ -18,6 +18,7 @@ class LightGlue(BaseModel):
         "model_name": "superpoint_lightglue.pth",
         "flash": True,  # enable FlashAttention if available.
         "mp": False,  # enable mixed precision
+        "add_scale_ori": False,
     }
     required_inputs = [
         "image0",
@@ -44,9 +45,18 @@ class LightGlue(BaseModel):
             "keypoints": data["keypoints0"],
             "descriptors": data["descriptors0"].permute(0, 2, 1),
         }
+        if "scales0" in data:
+            input["image0"] = {**input["image0"], "scales": data["scales0"]}
+        if "oris0" in data:
+            input["image0"] = {**input["image0"], "oris": data["oris0"]}
+
         input["image1"] = {
             "image": data["image1"],
             "keypoints": data["keypoints1"],
             "descriptors": data["descriptors1"].permute(0, 2, 1),
         }
+        if "scales1" in data:
+            input["image1"] = {**input["image1"], "scales": data["scales1"]}
+        if "oris1" in data:
+            input["image1"] = {**input["image1"], "oris": data["oris1"]}
         return self.net(input)
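Self-contained illustration of the merging logic added above: optional per-keypoint scales and orientations (produced e.g. by the new SIFT extractor) are attached to the per-image input dict only when present. Shapes are assumptions for a single image:

```python
import torch

N = 256
data = {
    "keypoints0": torch.rand(1, N, 2),
    "descriptors0": torch.rand(1, 128, N),  # (B, D, N) as stored by the extractor
    "scales0": torch.rand(1, N),
    "oris0": torch.rand(1, N),
}

image0 = {
    "keypoints": data["keypoints0"],
    "descriptors": data["descriptors0"].permute(0, 2, 1),  # -> (B, N, D) for LightGlue
}
if "scales0" in data:
    image0 = {**image0, "scales": data["scales0"]}
if "oris0" in data:
    image0 = {**image0, "oris": data["oris0"]}
print(sorted(image0.keys()))  # ['descriptors', 'keypoints', 'oris', 'scales']
```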
hloc/matchers/sgmnet.py
CHANGED
@@ -99,8 +99,12 @@ class SGMNet(BaseModel):
         score2 = data["scores1"].reshape(-1, 1)
         desc1 = data["descriptors0"].permute(0, 2, 1)  # 1 x N x 128
         desc2 = data["descriptors1"].permute(0, 2, 1)
-        size1 =
-
+        size1 = (
+            torch.tensor(data["image0"].shape[2:]).flip(0).to(x1.device)
+        )  # W x H -> x & y
+        size2 = (
+            torch.tensor(data["image1"].shape[2:]).flip(0).to(x2.device)
+        )  # W x H
         norm_x1 = self.normalize_size(x1, size1)
         norm_x2 = self.normalize_size(x2, size2)
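Why the `flip(0)`: image tensors are (B, C, H, W), so `shape[2:]` is (H, W), and flipping reorders it to (W, H), matching the x/y order of the keypoints passed to `normalize_size`. A standalone check:

```python
import torch

image0 = torch.rand(1, 1, 480, 640)  # H=480, W=640
size1 = torch.tensor(image0.shape[2:]).flip(0)
print(size1)  # tensor([640, 480]) -> (W, H)
```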
hloc/matchers/sold2.py
CHANGED
@@ -34,6 +34,7 @@ class SOLD2(BaseModel):
     weight_urls = {
         "sold2_wireframe.tar": "https://www.polybox.ethz.ch/index.php/s/blOrW89gqSLoHOk/download",
     }
+
     # Initialize the line matcher
     def _init(self, conf):
         checkpoint_path = conf["checkpoint_dir"] / conf["weights"]
hloc/utils/viz.py
CHANGED
@@ -71,6 +71,7 @@ def plot_keypoints(kpts, colors="lime", ps=4):
     except IndexError as e:
         pass

+
 def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
     """Plot matches for a pair of existing images.
     Args:
third_party/LightGlue/.flake8
ADDED
@@ -0,0 +1,4 @@
+[flake8]
+max-line-length = 88
+extend-ignore = E203
+exclude = .git,__pycache__,build,.venv/
third_party/LightGlue/.github/workflows/code-quality.yml
ADDED
@@ -0,0 +1,24 @@
+name: Format and Lint Checks
+on:
+  push:
+    branches:
+      - main
+    paths:
+      - '*.py'
+  pull_request:
+    types: [ assigned, opened, synchronize, reopened ]
+jobs:
+  check:
+    name: Format and Lint Checks
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - uses: actions/setup-python@v4
+        with:
+          python-version: '3.10'
+          cache: 'pip'
+      - run: python -m pip install --upgrade pip
+      - run: python -m pip install .[dev]
+      - run: python -m flake8 .
+      - run: python -m isort . --check-only --diff
+      - run: python -m black . --check --diff
third_party/LightGlue/.gitignore
CHANGED
@@ -1,10 +1,166 @@
-*.egg-info
-*.pyc
-/.idea/
 /data/
 /outputs/
-__pycache__
 /lightglue/weights/
-lightglue/_flash/
 *-checkpoint.ipynb
-*.pth
+*.pth
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
third_party/LightGlue/LICENSE
CHANGED
@@ -186,7 +186,7 @@
       same "printed page" as the copyright notice for easier
       identification within third-party archives.

-   Copyright
+   Copyright 2023 ETH Zurich

    Licensed under the Apache License, Version 2.0 (the "License");
    you may not use this file except in compliance with the License.
third_party/LightGlue/README.md
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
<p align="center">
|
2 |
-
<h1 align="center"><ins>LightGlue
|
3 |
<p align="center">
|
4 |
<a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
|
5 |
·
|
@@ -7,15 +7,14 @@
|
|
7 |
·
|
8 |
<a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc Pollefeys</a>
|
9 |
</p>
|
10 |
-
|
11 |
-
<
|
12 |
-
</p> -->
|
13 |
-
<!-- <h2 align="center">PrePrint 2023</h2> -->
|
14 |
-
<h2 align="center"><p>
|
15 |
<a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a> |
|
16 |
-
<a href="https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb" align="center">Colab</a>
|
17 |
-
|
18 |
-
|
|
|
|
|
19 |
</p>
|
20 |
<p align="center">
|
21 |
<a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="example" width=80%></a>
|
@@ -27,8 +26,8 @@
|
|
27 |
|
28 |
This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).
|
29 |
|
30 |
-
We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629)
|
31 |
-
The training
|
32 |
|
33 |
## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)
|
34 |
|
@@ -44,14 +43,14 @@ We provide a [demo notebook](demo.ipynb) which shows how to perform feature extr
|
|
44 |
Here is a minimal script to match two images:
|
45 |
|
46 |
```python
|
47 |
-
from lightglue import LightGlue, SuperPoint, DISK
|
48 |
from lightglue.utils import load_image, rbd
|
49 |
|
50 |
# SuperPoint+LightGlue
|
51 |
extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor
|
52 |
matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher
|
53 |
|
54 |
-
# or DISK+LightGlue
|
55 |
extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor
|
56 |
matcher = LightGlue(features='disk').eval().cuda() # load the matcher
|
57 |
|
@@ -88,6 +87,18 @@ feats0, feats1, matches01 = match_pair(extractor, matcher, image0, image1)
|
|
88 |
|
89 |
## Advanced configuration
|
90 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
|
92 |
```python
|
93 |
extractor = SuperPoint(max_num_keypoints=None)
|
@@ -99,31 +110,62 @@ To increase the speed with a small drop of accuracy, decrease the number of keyp
|
|
99 |
extractor = SuperPoint(max_num_keypoints=1024)
|
100 |
matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
|
101 |
```
|
102 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
<details>
|
105 |
-
<summary>[
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
- [```width_confidence```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L266): Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
|
112 |
-
- [```filter_threshold```](https://github.com/cvg/LightGlue/blob/main/lightglue/lightglue.py#L267): Match confidence. Increase this value to obtain less, but stronger matches. Default: 0.1
|
113 |
|
114 |
</details>
|
115 |
|
|
|
|
|
|
|
|
|
|
|
116 |
## Other links
|
117 |
- [hloc - the visual localization toolbox](https://github.com/cvg/Hierarchical-Localization/): run LightGlue for Structure-from-Motion and visual localization.
|
118 |
-
- [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange format.
|
119 |
- [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue.
|
120 |
-
- [kornia](kornia.readthedocs.io
|
121 |
|
122 |
-
## BibTeX
|
123 |
If you use any ideas from the paper or code from this repo, please consider citing:
|
124 |
|
125 |
```txt
|
126 |
-
@inproceedings{
|
127 |
author = {Philipp Lindenberger and
|
128 |
Paul-Edouard Sarlin and
|
129 |
Marc Pollefeys},
|
@@ -132,3 +174,7 @@ If you use any ideas from the paper or code from this repo, please consider citi
|
|
132 |
year = {2023}
|
133 |
}
|
134 |
```
|
|
|
|
|
|
|
|
|
|
1 |
<p align="center">
|
2 |
+
<h1 align="center"><ins>LightGlue</ins> ⚡️<br>Local Feature Matching at Light Speed</h1>
|
3 |
<p align="center">
|
4 |
<a href="https://www.linkedin.com/in/philipplindenberger/">Philipp Lindenberger</a>
|
5 |
·
|
|
|
7 |
·
|
8 |
<a href="https://www.microsoft.com/en-us/research/people/mapoll/">Marc Pollefeys</a>
|
9 |
</p>
|
10 |
+
<h2 align="center">
|
11 |
+
<p>ICCV 2023</p>
|
|
|
|
|
|
|
12 |
<a href="https://arxiv.org/pdf/2306.13643.pdf" align="center">Paper</a> |
|
13 |
+
<a href="https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb" align="center">Colab</a> |
|
14 |
+
<a href="https://psarlin.com/assets/LightGlue_ICCV2023_poster_compressed.pdf" align="center">Poster</a> |
|
15 |
+
<a href="https://github.com/cvg/glue-factory" align="center">Train your own!</a>
|
16 |
+
</h2>
|
17 |
+
|
18 |
</p>
|
19 |
<p align="center">
|
20 |
<a href="https://arxiv.org/abs/2306.13643"><img src="assets/easy_hard.jpg" alt="example" width=80%></a>
|
|
|
26 |
|
27 |
This repository hosts the inference code of LightGlue, a lightweight feature matcher with high accuracy and blazing fast inference. It takes as input a set of keypoints and descriptors for each image and returns the indices of corresponding points. The architecture is based on adaptive pruning techniques, in both network width and depth - [check out the paper for more details](https://arxiv.org/pdf/2306.13643.pdf).
|
28 |
|
29 |
+
We release pretrained weights of LightGlue with [SuperPoint](https://arxiv.org/abs/1712.07629), [DISK](https://arxiv.org/abs/2006.13566), [ALIKED](https://arxiv.org/abs/2304.03608) and [SIFT](https://www.cs.ubc.ca/~lowe/papers/ijcv04.pdf) local features.
|
30 |
+
The training and evaluation code can be found in our library [glue-factory](https://github.com/cvg/glue-factory/).
|
31 |
|
32 |
## Installation and demo [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cvg/LightGlue/blob/main/demo.ipynb)
|
33 |
|
|
|
43 |
Here is a minimal script to match two images:
|
44 |
|
45 |
```python
|
46 |
+
from lightglue import LightGlue, SuperPoint, DISK, SIFT, ALIKED, DoGHardNet
|
47 |
from lightglue.utils import load_image, rbd
|
48 |
|
49 |
# SuperPoint+LightGlue
|
50 |
extractor = SuperPoint(max_num_keypoints=2048).eval().cuda() # load the extractor
|
51 |
matcher = LightGlue(features='superpoint').eval().cuda() # load the matcher
|
52 |
|
53 |
+
# or DISK+LightGlue, ALIKED+LightGlue or SIFT+LightGlue
|
54 |
extractor = DISK(max_num_keypoints=2048).eval().cuda() # load the extractor
|
55 |
matcher = LightGlue(features='disk').eval().cuda() # load the matcher
|
56 |
|
|
|
87 |
|
88 |
## Advanced configuration

<details>
<summary>[Detail of all parameters - click to expand]</summary>

- ```n_layers```: Number of stacked self+cross attention layers. Reduce this value for faster inference at the cost of accuracy (continuous red line in the plot above). Default: 9 (all layers).
- ```flash```: Enable FlashAttention. Significantly increases the speed and reduces the memory consumption without any impact on accuracy. Default: True (LightGlue automatically detects if FlashAttention is available).
- ```mp```: Enable mixed precision inference. Default: False (off)
- ```depth_confidence```: Controls the early stopping. A lower value stops more often at earlier layers. Default: 0.95, disable with -1.
- ```width_confidence```: Controls the iterative point pruning. A lower value prunes more points earlier. Default: 0.99, disable with -1.
- ```filter_threshold```: Match confidence. Increase this value to obtain fewer, but stronger matches. Default: 0.1

</details>
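Putting the documented parameters together, a configuration sketch (the values below are only illustrative; the defaults listed above are usually fine):

```python
from lightglue import LightGlue

matcher = LightGlue(
    features='superpoint',
    n_layers=9,              # use all attention layers
    flash=True,              # FlashAttention when available
    mp=False,                # mixed precision inference off
    depth_confidence=0.95,   # early stopping; -1 disables it
    width_confidence=0.99,   # point pruning; -1 disables it
    filter_threshold=0.1,    # keep matches above this confidence
).eval().cuda()
```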
The default values give a good trade-off between speed and accuracy. To maximize the accuracy, use all keypoints and disable the adaptive mechanisms:
```python
extractor = SuperPoint(max_num_keypoints=None)
```
To speed up inference with a small accuracy cost, limit the number of keypoints and lower the adaptivity thresholds:
```python
extractor = SuperPoint(max_num_keypoints=1024)
matcher = LightGlue(features='superpoint', depth_confidence=0.9, width_confidence=0.95)
```

The maximum speed is obtained with a combination of:
- [FlashAttention](https://arxiv.org/abs/2205.14135): automatically used when ```torch >= 2.0``` or if [installed from source](https://github.com/HazyResearch/flash-attention#installation-and-features).
- PyTorch compilation, available when ```torch >= 2.0```:
```python
matcher = matcher.eval().cuda()
matcher.compile(mode='reduce-overhead')
```
For inputs with fewer than 1536 keypoints (determined experimentally), this compiles LightGlue but disables point pruning (large overhead). For larger input sizes, it automatically falls back to eager mode with point pruning. Adaptive depth is supported for any input size.
## Benchmark

<p align="center">
  <a><img src="assets/benchmark.png" alt="Logo" width=80%></a>
  <br>
  <em>Benchmark results on GPU (RTX 3080). With compilation and adaptivity, LightGlue runs at 150 FPS @ 1024 keypoints and 50 FPS @ 4096 keypoints per image. This is a 4-10x speedup over SuperGlue.</em>
</p>

<p align="center">
  <a><img src="assets/benchmark_cpu.png" alt="Logo" width=80%></a>
  <br>
  <em>Benchmark results on CPU (Intel i7 10700K). LightGlue runs at 20 FPS @ 512 keypoints.</em>
</p>

Obtain the same plots for your setup using our [benchmark script](benchmark.py):
```
python benchmark.py [--device cuda] [--add_superglue] [--num_keypoints 512 1024 2048 4096] [--compile]
```
<details>
<summary>[Performance tip - click to expand]</summary>

Note: **Point pruning** introduces an overhead that sometimes outweighs its benefits.
Point pruning is thus enabled only when there are more than N keypoints in an image, where N is hardware-dependent.
We provide defaults optimized for current hardware (RTX 30xx GPUs).
We suggest running the benchmark script and adjusting the thresholds for your hardware by updating `LightGlue.pruning_keypoint_thresholds['cuda']`.

</details>
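A sketch of such an adjustment (the value 2048 is illustrative, not a recommendation; pick one based on your own benchmark run):

```python
from lightglue import LightGlue

# Only enable point pruning when an image has more than 2048 keypoints.
LightGlue.pruning_keypoint_thresholds['cuda'] = 2048
matcher = LightGlue(features='superpoint').eval().cuda()
```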
## Training and evaluation

With [Glue Factory](https://github.com/cvg/glue-factory), you can train LightGlue with your own local features, on your own dataset!
You can also evaluate it and other baselines on standard benchmarks like HPatches and MegaDepth.

## Other links
- [hloc - the visual localization toolbox](https://github.com/cvg/Hierarchical-Localization/): run LightGlue for Structure-from-Motion and visual localization.
- [LightGlue-ONNX](https://github.com/fabio-sim/LightGlue-ONNX): export LightGlue to the Open Neural Network Exchange (ONNX) format with support for TensorRT and OpenVINO.
- [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui): a web GUI to easily compare different matchers, including LightGlue.
- [kornia](https://kornia.readthedocs.io) now exposes LightGlue via the interfaces [`LightGlue`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlue) and [`LightGlueMatcher`](https://kornia.readthedocs.io/en/latest/feature.html#kornia.feature.LightGlueMatcher).

## BibTeX citation
If you use any ideas from the paper or code from this repo, please consider citing:

```txt
@inproceedings{lindenberger2023lightglue,
  author    = {Philipp Lindenberger and
               Paul-Edouard Sarlin and
               Marc Pollefeys},
  title     = {{LightGlue}: Local Feature Matching at Light Speed},
  booktitle = {ICCV},
  year      = {2023}
}
```

## License
The pre-trained weights of LightGlue and the code provided in this repository are released under the [Apache-2.0 license](./LICENSE). [DISK](https://github.com/cvlab-epfl/disk) follows this license as well but SuperPoint follows [a different, restrictive license](https://github.com/magicleap/SuperPointPretrainedNetwork/blob/master/LICENSE) (this includes its pre-trained weights and its [inference file](./lightglue/superpoint.py)). [ALIKED](https://github.com/Shiaoming/ALIKED) was published under a BSD-3-Clause license.
third_party/LightGlue/assets/DSC_0410.JPG
CHANGED (Git LFS file)

third_party/LightGlue/assets/DSC_0411.JPG
CHANGED (Git LFS file)

third_party/LightGlue/assets/benchmark.png
ADDED (Git LFS file)

third_party/LightGlue/assets/benchmark_cpu.png
ADDED (Git LFS file)
third_party/LightGlue/benchmark.py
ADDED
@@ -0,0 +1,255 @@
1 |
+
# Benchmark script for LightGlue on real images
|
2 |
+
import argparse
|
3 |
+
import time
|
4 |
+
from collections import defaultdict
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
import torch._dynamo
|
11 |
+
|
12 |
+
from lightglue import LightGlue, SuperPoint
|
13 |
+
from lightglue.utils import load_image
|
14 |
+
|
15 |
+
torch.set_grad_enabled(False)
|
16 |
+
|
17 |
+
|
18 |
+
def measure(matcher, data, device="cuda", r=100):
|
19 |
+
timings = np.zeros((r, 1))
|
20 |
+
if device.type == "cuda":
|
21 |
+
starter = torch.cuda.Event(enable_timing=True)
|
22 |
+
ender = torch.cuda.Event(enable_timing=True)
|
23 |
+
# warmup
|
24 |
+
for _ in range(10):
|
25 |
+
_ = matcher(data)
|
26 |
+
# measurements
|
27 |
+
with torch.no_grad():
|
28 |
+
for rep in range(r):
|
29 |
+
if device.type == "cuda":
|
30 |
+
starter.record()
|
31 |
+
_ = matcher(data)
|
32 |
+
ender.record()
|
33 |
+
# sync gpu
|
34 |
+
torch.cuda.synchronize()
|
35 |
+
curr_time = starter.elapsed_time(ender)
|
36 |
+
else:
|
37 |
+
start = time.perf_counter()
|
38 |
+
_ = matcher(data)
|
39 |
+
curr_time = (time.perf_counter() - start) * 1e3
|
40 |
+
timings[rep] = curr_time
|
41 |
+
mean_syn = np.sum(timings) / r
|
42 |
+
std_syn = np.std(timings)
|
43 |
+
return {"mean": mean_syn, "std": std_syn}
|
44 |
+
|
45 |
+
|
46 |
+
def print_as_table(d, title, cnames):
|
47 |
+
print()
|
48 |
+
header = f"{title:30} " + " ".join([f"{x:>7}" for x in cnames])
|
49 |
+
print(header)
|
50 |
+
print("-" * len(header))
|
51 |
+
for k, l in d.items():
|
52 |
+
print(f"{k:30}", " ".join([f"{x:>7.1f}" for x in l]))
|
53 |
+
|
54 |
+
|
55 |
+
if __name__ == "__main__":
|
56 |
+
parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
|
57 |
+
parser.add_argument(
|
58 |
+
"--device",
|
59 |
+
choices=["auto", "cuda", "cpu", "mps"],
|
60 |
+
default="auto",
|
61 |
+
help="device to benchmark on",
|
62 |
+
)
|
63 |
+
parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
|
64 |
+
parser.add_argument(
|
65 |
+
"--no_flash", action="store_true", help="disable FlashAttention"
|
66 |
+
)
|
67 |
+
parser.add_argument(
|
68 |
+
"--no_prune_thresholds",
|
69 |
+
action="store_true",
|
70 |
+
help="disable pruning thresholds (i.e. always do pruning)",
|
71 |
+
)
|
72 |
+
parser.add_argument(
|
73 |
+
"--add_superglue",
|
74 |
+
action="store_true",
|
75 |
+
help="add SuperGlue to the benchmark (requires hloc)",
|
76 |
+
)
|
77 |
+
parser.add_argument(
|
78 |
+
"--measure", default="time", choices=["time", "log-time", "throughput"]
|
79 |
+
)
|
80 |
+
parser.add_argument(
|
81 |
+
"--repeat", "--r", type=int, default=100, help="repetitions of measurements"
|
82 |
+
)
|
83 |
+
parser.add_argument(
|
84 |
+
"--num_keypoints",
|
85 |
+
nargs="+",
|
86 |
+
type=int,
|
87 |
+
default=[256, 512, 1024, 2048, 4096],
|
88 |
+
help="number of keypoints (list separated by spaces)",
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--matmul_precision", default="highest", choices=["highest", "high", "medium"]
|
92 |
+
)
|
93 |
+
parser.add_argument(
|
94 |
+
"--save", default=None, type=str, help="path where figure should be saved"
|
95 |
+
)
|
96 |
+
args = parser.parse_intermixed_args()
|
97 |
+
|
98 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
99 |
+
if args.device != "auto":
|
100 |
+
device = torch.device(args.device)
|
101 |
+
|
102 |
+
print("Running benchmark on device:", device)
|
103 |
+
|
104 |
+
images = Path("assets")
|
105 |
+
inputs = {
|
106 |
+
"easy": (
|
107 |
+
load_image(images / "DSC_0411.JPG"),
|
108 |
+
load_image(images / "DSC_0410.JPG"),
|
109 |
+
),
|
110 |
+
"difficult": (
|
111 |
+
load_image(images / "sacre_coeur1.jpg"),
|
112 |
+
load_image(images / "sacre_coeur2.jpg"),
|
113 |
+
),
|
114 |
+
}
|
115 |
+
|
116 |
+
configs = {
|
117 |
+
"LightGlue-full": {
|
118 |
+
"depth_confidence": -1,
|
119 |
+
"width_confidence": -1,
|
120 |
+
},
|
121 |
+
# 'LG-prune': {
|
122 |
+
# 'width_confidence': -1,
|
123 |
+
# },
|
124 |
+
# 'LG-depth': {
|
125 |
+
# 'depth_confidence': -1,
|
126 |
+
# },
|
127 |
+
"LightGlue-adaptive": {},
|
128 |
+
}
|
129 |
+
|
130 |
+
if args.compile:
|
131 |
+
configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}
|
132 |
+
|
133 |
+
sg_configs = {
|
134 |
+
# 'SuperGlue': {},
|
135 |
+
"SuperGlue-fast": {"sinkhorn_iterations": 5}
|
136 |
+
}
|
137 |
+
|
138 |
+
torch.set_float32_matmul_precision(args.matmul_precision)
|
139 |
+
|
140 |
+
results = {k: defaultdict(list) for k, v in inputs.items()}
|
141 |
+
|
142 |
+
extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
|
143 |
+
extractor = extractor.eval().to(device)
|
144 |
+
figsize = (len(inputs) * 4.5, 4.5)
|
145 |
+
fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
|
146 |
+
axes = axes if len(inputs) > 1 else [axes]
|
147 |
+
fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
|
148 |
+
|
149 |
+
for title, ax in zip(inputs.keys(), axes):
|
150 |
+
ax.set_xscale("log", base=2)
|
151 |
+
bases = [2**x for x in range(7, 16)]
|
152 |
+
ax.set_xticks(bases, bases)
|
153 |
+
ax.grid(which="major")
|
154 |
+
if args.measure == "log-time":
|
155 |
+
ax.set_yscale("log")
|
156 |
+
yticks = [10**x for x in range(6)]
|
157 |
+
ax.set_yticks(yticks, yticks)
|
158 |
+
mpos = [10**x * i for x in range(6) for i in range(2, 10)]
|
159 |
+
mlabel = [
|
160 |
+
10**x * i if i in [2, 5] else None
|
161 |
+
for x in range(6)
|
162 |
+
for i in range(2, 10)
|
163 |
+
]
|
164 |
+
ax.set_yticks(mpos, mlabel, minor=True)
|
165 |
+
ax.grid(which="minor", linewidth=0.2)
|
166 |
+
ax.set_title(title)
|
167 |
+
|
168 |
+
ax.set_xlabel("# keypoints")
|
169 |
+
if args.measure == "throughput":
|
170 |
+
ax.set_ylabel("Throughput [pairs/s]")
|
171 |
+
else:
|
172 |
+
ax.set_ylabel("Latency [ms]")
|
173 |
+
|
174 |
+
for name, conf in configs.items():
|
175 |
+
print("Run benchmark for:", name)
|
176 |
+
torch.cuda.empty_cache()
|
177 |
+
matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
|
178 |
+
if args.no_prune_thresholds:
|
179 |
+
matcher.pruning_keypoint_thresholds = {
|
180 |
+
k: -1 for k in matcher.pruning_keypoint_thresholds
|
181 |
+
}
|
182 |
+
matcher = matcher.eval().to(device)
|
183 |
+
if name.endswith("compile"):
|
184 |
+
import torch._dynamo
|
185 |
+
|
186 |
+
torch._dynamo.reset() # avoid buffer overflow
|
187 |
+
matcher.compile()
|
188 |
+
for pair_name, ax in zip(inputs.keys(), axes):
|
189 |
+
image0, image1 = [x.to(device) for x in inputs[pair_name]]
|
190 |
+
runtimes = []
|
191 |
+
for num_kpts in args.num_keypoints:
|
192 |
+
extractor.conf.max_num_keypoints = num_kpts
|
193 |
+
feats0 = extractor.extract(image0)
|
194 |
+
feats1 = extractor.extract(image1)
|
195 |
+
runtime = measure(
|
196 |
+
matcher,
|
197 |
+
{"image0": feats0, "image1": feats1},
|
198 |
+
device=device,
|
199 |
+
r=args.repeat,
|
200 |
+
)["mean"]
|
201 |
+
results[pair_name][name].append(
|
202 |
+
1000 / runtime if args.measure == "throughput" else runtime
|
203 |
+
)
|
204 |
+
ax.plot(
|
205 |
+
args.num_keypoints, results[pair_name][name], label=name, marker="o"
|
206 |
+
)
|
207 |
+
del matcher, feats0, feats1
|
208 |
+
|
209 |
+
if args.add_superglue:
|
210 |
+
from hloc.matchers.superglue import SuperGlue
|
211 |
+
|
212 |
+
for name, conf in sg_configs.items():
|
213 |
+
print("Run benchmark for:", name)
|
214 |
+
matcher = SuperGlue(conf)
|
215 |
+
matcher = matcher.eval().to(device)
|
216 |
+
for pair_name, ax in zip(inputs.keys(), axes):
|
217 |
+
image0, image1 = [x.to(device) for x in inputs[pair_name]]
|
218 |
+
runtimes = []
|
219 |
+
for num_kpts in args.num_keypoints:
|
220 |
+
extractor.conf.max_num_keypoints = num_kpts
|
221 |
+
feats0 = extractor.extract(image0)
|
222 |
+
feats1 = extractor.extract(image1)
|
223 |
+
data = {
|
224 |
+
"image0": image0[None],
|
225 |
+
"image1": image1[None],
|
226 |
+
**{k + "0": v for k, v in feats0.items()},
|
227 |
+
**{k + "1": v for k, v in feats1.items()},
|
228 |
+
}
|
229 |
+
data["scores0"] = data["keypoint_scores0"]
|
230 |
+
data["scores1"] = data["keypoint_scores1"]
|
231 |
+
data["descriptors0"] = (
|
232 |
+
data["descriptors0"].transpose(-1, -2).contiguous()
|
233 |
+
)
|
234 |
+
data["descriptors1"] = (
|
235 |
+
data["descriptors1"].transpose(-1, -2).contiguous()
|
236 |
+
)
|
237 |
+
runtime = measure(matcher, data, device=device, r=args.repeat)[
|
238 |
+
"mean"
|
239 |
+
]
|
240 |
+
results[pair_name][name].append(
|
241 |
+
1000 / runtime if args.measure == "throughput" else runtime
|
242 |
+
)
|
243 |
+
ax.plot(
|
244 |
+
args.num_keypoints, results[pair_name][name], label=name, marker="o"
|
245 |
+
)
|
246 |
+
del matcher, data, image0, image1, feats0, feats1
|
247 |
+
|
248 |
+
for name, runtimes in results.items():
|
249 |
+
print_as_table(runtimes, name, args.num_keypoints)
|
250 |
+
|
251 |
+
axes[0].legend()
|
252 |
+
fig.tight_layout()
|
253 |
+
if args.save:
|
254 |
+
plt.savefig(args.save, dpi=fig.dpi)
|
255 |
+
plt.show()
|
third_party/LightGlue/demo.ipynb
CHANGED
@@ -16,16 +16,19 @@ (setup cell, updated source)
```python
# If we are on colab: this clones the repo and installs the dependencies
from pathlib import Path

if Path.cwd().name != "LightGlue":
    !git clone --quiet https://github.com/cvg/LightGlue/
    %cd LightGlue
    !pip install --progress-bar off --quiet -e .

from lightglue import LightGlue, SuperPoint, DISK
from lightglue.utils import load_image, rbd
from lightglue import viz2d
import torch

torch.set_grad_enabled(False)
images = Path("assets")
```

@@ -51,10 +54,10 @@ (model loading cell, updated source)
```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 'mps', 'cpu'

extractor = SuperPoint(max_num_keypoints=2048).eval().to(device)  # load the extractor
matcher = LightGlue(features="superpoint").eval().to(device)
```

@@ -92,22 +95,24 @@ (easy pair cell, updated source)
```python
image0 = load_image(images / "DSC_0411.JPG")
image1 = load_image(images / "DSC_0410.JPG")

feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
matches01 = matcher({"image0": feats0, "image1": feats1})
feats0, feats1, matches01 = [
    rbd(x) for x in [feats0, feats1, matches01]
]  # remove batch dimension

kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

axes = viz2d.plot_images([image0, image1])
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
viz2d.add_text(0, f'Stop after {matches01["stop"]} layers', fs=20)

kpc0, kpc1 = viz2d.cm_prune(matches01["prune0"]), viz2d.cm_prune(matches01["prune1"])
viz2d.plot_images([image0, image1])
viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)
```

@@ -147,22 +152,24 @@ (difficult pair cell, updated source)
```python
image0 = load_image(images / "sacre_coeur1.jpg")
image1 = load_image(images / "sacre_coeur2.jpg")

feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
matches01 = matcher({"image0": feats0, "image1": feats1})
feats0, feats1, matches01 = [
    rbd(x) for x in [feats0, feats1, matches01]
]  # remove batch dimension

kpts0, kpts1, matches = feats0["keypoints"], feats1["keypoints"], matches01["matches"]
m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

axes = viz2d.plot_images([image0, image1])
viz2d.plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
viz2d.add_text(0, f'Stop after {matches01["stop"]} layers')

kpc0, kpc1 = viz2d.cm_prune(matches01["prune0"]), viz2d.cm_prune(matches01["prune1"])
viz2d.plot_images([image0, image1])
viz2d.plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=6)
```
third_party/LightGlue/lightglue/__init__.py
CHANGED
@@ -1,4 +1,7 @@
-from .
-from .
-from .
-from .
+from .aliked import ALIKED  # noqa
+from .disk import DISK  # noqa
+from .dog_hardnet import DoGHardNet  # noqa
+from .lightglue import LightGlue  # noqa
+from .sift import SIFT  # noqa
+from .superpoint import SuperPoint  # noqa
+from .utils import match_pair  # noqa
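These exports make the added extractors available at the package root. A usage sketch with the newly exported SIFT, following the README script above (the `features='sift'` key is chosen by analogy with `'superpoint'` and `'disk'`):

```python
from lightglue import SIFT, LightGlue
from lightglue.utils import load_image, rbd

extractor = SIFT(max_num_keypoints=2048).eval().cuda()       # SIFT keypoints + descriptors
matcher = LightGlue(features='sift').eval().cuda()           # SIFT-specific LightGlue weights

feats0 = extractor.extract(load_image('assets/DSC_0411.JPG').cuda())
feats1 = extractor.extract(load_image('assets/DSC_0410.JPG').cuda())
matches01 = rbd(matcher({'image0': feats0, 'image1': feats1}))
```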
third_party/LightGlue/lightglue/aliked.py
ADDED
@@ -0,0 +1,758 @@
1 |
+
# BSD 3-Clause License
|
2 |
+
|
3 |
+
# Copyright (c) 2022, Zhao Xiaoming
|
4 |
+
# All rights reserved.
|
5 |
+
|
6 |
+
# Redistribution and use in source and binary forms, with or without
|
7 |
+
# modification, are permitted provided that the following conditions are met:
|
8 |
+
|
9 |
+
# 1. Redistributions of source code must retain the above copyright notice, this
|
10 |
+
# list of conditions and the following disclaimer.
|
11 |
+
|
12 |
+
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
13 |
+
# this list of conditions and the following disclaimer in the documentation
|
14 |
+
# and/or other materials provided with the distribution.
|
15 |
+
|
16 |
+
# 3. Neither the name of the copyright holder nor the names of its
|
17 |
+
# contributors may be used to endorse or promote products derived from
|
18 |
+
# this software without specific prior written permission.
|
19 |
+
|
20 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
21 |
+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
22 |
+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
23 |
+
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
24 |
+
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
25 |
+
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
26 |
+
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
27 |
+
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
28 |
+
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
29 |
+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
30 |
+
|
31 |
+
# Authors:
|
32 |
+
# Xiaoming Zhao, Xingming Wu, Weihai Chen, Peter C.Y. Chen, Qingsong Xu, and Zhengguo Li
|
33 |
+
# Code from https://github.com/Shiaoming/ALIKED
|
34 |
+
|
35 |
+
from typing import Callable, Optional
|
36 |
+
|
37 |
+
import torch
|
38 |
+
import torch.nn.functional as F
|
39 |
+
import torchvision
|
40 |
+
from kornia.color import grayscale_to_rgb
|
41 |
+
from torch import nn
|
42 |
+
from torch.nn.modules.utils import _pair
|
43 |
+
from torchvision.models import resnet
|
44 |
+
|
45 |
+
from .utils import Extractor
|
46 |
+
|
47 |
+
|
48 |
+
def get_patches(
|
49 |
+
tensor: torch.Tensor, required_corners: torch.Tensor, ps: int
|
50 |
+
) -> torch.Tensor:
|
51 |
+
c, h, w = tensor.shape
|
52 |
+
corner = (required_corners - ps / 2 + 1).long()
|
53 |
+
corner[:, 0] = corner[:, 0].clamp(min=0, max=w - 1 - ps)
|
54 |
+
corner[:, 1] = corner[:, 1].clamp(min=0, max=h - 1 - ps)
|
55 |
+
offset = torch.arange(0, ps)
|
56 |
+
|
57 |
+
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
|
58 |
+
x, y = torch.meshgrid(offset, offset, **kw)
|
59 |
+
patches = torch.stack((x, y)).permute(2, 1, 0).unsqueeze(2)
|
60 |
+
patches = patches.to(corner) + corner[None, None]
|
61 |
+
pts = patches.reshape(-1, 2)
|
62 |
+
sampled = tensor.permute(1, 2, 0)[tuple(pts.T)[::-1]]
|
63 |
+
sampled = sampled.reshape(ps, ps, -1, c)
|
64 |
+
assert sampled.shape[:3] == patches.shape[:3]
|
65 |
+
return sampled.permute(2, 3, 0, 1)
|
66 |
+
|
67 |
+
|
68 |
+
def simple_nms(scores: torch.Tensor, nms_radius: int):
|
69 |
+
"""Fast Non-maximum suppression to remove nearby points"""
|
70 |
+
|
71 |
+
zeros = torch.zeros_like(scores)
|
72 |
+
max_mask = scores == torch.nn.functional.max_pool2d(
|
73 |
+
scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
74 |
+
)
|
75 |
+
|
76 |
+
for _ in range(2):
|
77 |
+
supp_mask = (
|
78 |
+
torch.nn.functional.max_pool2d(
|
79 |
+
max_mask.float(),
|
80 |
+
kernel_size=nms_radius * 2 + 1,
|
81 |
+
stride=1,
|
82 |
+
padding=nms_radius,
|
83 |
+
)
|
84 |
+
> 0
|
85 |
+
)
|
86 |
+
supp_scores = torch.where(supp_mask, zeros, scores)
|
87 |
+
new_max_mask = supp_scores == torch.nn.functional.max_pool2d(
|
88 |
+
supp_scores, kernel_size=nms_radius * 2 + 1, stride=1, padding=nms_radius
|
89 |
+
)
|
90 |
+
max_mask = max_mask | (new_max_mask & (~supp_mask))
|
91 |
+
return torch.where(max_mask, scores, zeros)
|
92 |
+
|
93 |
+
|
94 |
+
class DKD(nn.Module):
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
radius: int = 2,
|
98 |
+
top_k: int = 0,
|
99 |
+
scores_th: float = 0.2,
|
100 |
+
n_limit: int = 20000,
|
101 |
+
):
|
102 |
+
"""
|
103 |
+
Args:
|
104 |
+
radius: soft detection radius, kernel size is (2 * radius + 1)
|
105 |
+
top_k: top_k > 0: return top k keypoints
|
106 |
+
scores_th: top_k <= 0 threshold mode:
|
107 |
+
scores_th > 0: return keypoints with scores>scores_th
|
108 |
+
else: return keypoints with scores > scores.mean()
|
109 |
+
n_limit: max number of keypoint in threshold mode
|
110 |
+
"""
|
111 |
+
super().__init__()
|
112 |
+
self.radius = radius
|
113 |
+
self.top_k = top_k
|
114 |
+
self.scores_th = scores_th
|
115 |
+
self.n_limit = n_limit
|
116 |
+
self.kernel_size = 2 * self.radius + 1
|
117 |
+
self.temperature = 0.1 # tuned temperature
|
118 |
+
self.unfold = nn.Unfold(kernel_size=self.kernel_size, padding=self.radius)
|
119 |
+
# local xy grid
|
120 |
+
x = torch.linspace(-self.radius, self.radius, self.kernel_size)
|
121 |
+
# (kernel_size*kernel_size) x 2 : (w,h)
|
122 |
+
kw = {"indexing": "ij"} if torch.__version__ >= "1.10" else {}
|
123 |
+
self.hw_grid = (
|
124 |
+
torch.stack(torch.meshgrid([x, x], **kw)).view(2, -1).t()[:, [1, 0]]
|
125 |
+
)
|
126 |
+
|
127 |
+
def forward(
|
128 |
+
self,
|
129 |
+
scores_map: torch.Tensor,
|
130 |
+
sub_pixel: bool = True,
|
131 |
+
image_size: Optional[torch.Tensor] = None,
|
132 |
+
):
|
133 |
+
"""
|
134 |
+
:param scores_map: Bx1xHxW
|
135 |
+
:param descriptor_map: BxCxHxW
|
136 |
+
:param sub_pixel: whether to use sub-pixel keypoint detection
|
137 |
+
:return: kpts: list[Nx2,...]; kptscores: list[N,....] normalised position: -1~1
|
138 |
+
"""
|
139 |
+
b, c, h, w = scores_map.shape
|
140 |
+
scores_nograd = scores_map.detach()
|
141 |
+
nms_scores = simple_nms(scores_nograd, self.radius)
|
142 |
+
|
143 |
+
# remove border
|
144 |
+
nms_scores[:, :, : self.radius, :] = 0
|
145 |
+
nms_scores[:, :, :, : self.radius] = 0
|
146 |
+
if image_size is not None:
|
147 |
+
for i in range(scores_map.shape[0]):
|
148 |
+
w, h = image_size[i].long()
|
149 |
+
nms_scores[i, :, h.item() - self.radius :, :] = 0
|
150 |
+
nms_scores[i, :, :, w.item() - self.radius :] = 0
|
151 |
+
else:
|
152 |
+
nms_scores[:, :, -self.radius :, :] = 0
|
153 |
+
nms_scores[:, :, :, -self.radius :] = 0
|
154 |
+
|
155 |
+
# detect keypoints without grad
|
156 |
+
if self.top_k > 0:
|
157 |
+
topk = torch.topk(nms_scores.view(b, -1), self.top_k)
|
158 |
+
indices_keypoints = [topk.indices[i] for i in range(b)] # B x top_k
|
159 |
+
else:
|
160 |
+
if self.scores_th > 0:
|
161 |
+
masks = nms_scores > self.scores_th
|
162 |
+
if masks.sum() == 0:
|
163 |
+
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
|
164 |
+
masks = nms_scores > th.reshape(b, 1, 1, 1)
|
165 |
+
else:
|
166 |
+
th = scores_nograd.reshape(b, -1).mean(dim=1) # th = self.scores_th
|
167 |
+
masks = nms_scores > th.reshape(b, 1, 1, 1)
|
168 |
+
masks = masks.reshape(b, -1)
|
169 |
+
|
170 |
+
indices_keypoints = [] # list, B x (any size)
|
171 |
+
scores_view = scores_nograd.reshape(b, -1)
|
172 |
+
for mask, scores in zip(masks, scores_view):
|
173 |
+
indices = mask.nonzero()[:, 0]
|
174 |
+
if len(indices) > self.n_limit:
|
175 |
+
kpts_sc = scores[indices]
|
176 |
+
sort_idx = kpts_sc.sort(descending=True)[1]
|
177 |
+
sel_idx = sort_idx[: self.n_limit]
|
178 |
+
indices = indices[sel_idx]
|
179 |
+
indices_keypoints.append(indices)
|
180 |
+
|
181 |
+
wh = torch.tensor([w - 1, h - 1], device=scores_nograd.device)
|
182 |
+
|
183 |
+
keypoints = []
|
184 |
+
scoredispersitys = []
|
185 |
+
kptscores = []
|
186 |
+
if sub_pixel:
|
187 |
+
# detect soft keypoints with grad backpropagation
|
188 |
+
patches = self.unfold(scores_map) # B x (kernel**2) x (H*W)
|
189 |
+
self.hw_grid = self.hw_grid.to(scores_map) # to device
|
190 |
+
for b_idx in range(b):
|
191 |
+
patch = patches[b_idx].t() # (H*W) x (kernel**2)
|
192 |
+
indices_kpt = indices_keypoints[
|
193 |
+
b_idx
|
194 |
+
] # one dimension vector, say its size is M
|
195 |
+
patch_scores = patch[indices_kpt] # M x (kernel**2)
|
196 |
+
keypoints_xy_nms = torch.stack(
|
197 |
+
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
|
198 |
+
dim=1,
|
199 |
+
) # Mx2
|
200 |
+
|
201 |
+
# max is detached to prevent undesired backprop loops in the graph
|
202 |
+
max_v = patch_scores.max(dim=1).values.detach()[:, None]
|
203 |
+
x_exp = (
|
204 |
+
(patch_scores - max_v) / self.temperature
|
205 |
+
).exp() # M * (kernel**2), in [0, 1]
|
206 |
+
|
207 |
+
# \frac{ \sum{(i,j) \times \exp(x/T)} }{ \sum{\exp(x/T)} }
|
208 |
+
xy_residual = (
|
209 |
+
x_exp @ self.hw_grid / x_exp.sum(dim=1)[:, None]
|
210 |
+
) # Soft-argmax, Mx2
|
211 |
+
|
212 |
+
hw_grid_dist2 = (
|
213 |
+
torch.norm(
|
214 |
+
(self.hw_grid[None, :, :] - xy_residual[:, None, :])
|
215 |
+
/ self.radius,
|
216 |
+
dim=-1,
|
217 |
+
)
|
218 |
+
** 2
|
219 |
+
)
|
220 |
+
scoredispersity = (x_exp * hw_grid_dist2).sum(dim=1) / x_exp.sum(dim=1)
|
221 |
+
|
222 |
+
# compute result keypoints
|
223 |
+
keypoints_xy = keypoints_xy_nms + xy_residual
|
224 |
+
keypoints_xy = keypoints_xy / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
|
225 |
+
|
226 |
+
kptscore = torch.nn.functional.grid_sample(
|
227 |
+
scores_map[b_idx].unsqueeze(0),
|
228 |
+
keypoints_xy.view(1, 1, -1, 2),
|
229 |
+
mode="bilinear",
|
230 |
+
align_corners=True,
|
231 |
+
)[
|
232 |
+
0, 0, 0, :
|
233 |
+
] # CxN
|
234 |
+
|
235 |
+
keypoints.append(keypoints_xy)
|
236 |
+
scoredispersitys.append(scoredispersity)
|
237 |
+
kptscores.append(kptscore)
|
238 |
+
else:
|
239 |
+
for b_idx in range(b):
|
240 |
+
indices_kpt = indices_keypoints[
|
241 |
+
b_idx
|
242 |
+
] # one dimension vector, say its size is M
|
243 |
+
# To avoid warning: UserWarning: __floordiv__ is deprecated
|
244 |
+
keypoints_xy_nms = torch.stack(
|
245 |
+
[indices_kpt % w, torch.div(indices_kpt, w, rounding_mode="trunc")],
|
246 |
+
dim=1,
|
247 |
+
) # Mx2
|
248 |
+
keypoints_xy = keypoints_xy_nms / wh * 2 - 1 # (w,h) -> (-1~1,-1~1)
|
249 |
+
kptscore = torch.nn.functional.grid_sample(
|
250 |
+
scores_map[b_idx].unsqueeze(0),
|
251 |
+
keypoints_xy.view(1, 1, -1, 2),
|
252 |
+
mode="bilinear",
|
253 |
+
align_corners=True,
|
254 |
+
)[
|
255 |
+
0, 0, 0, :
|
256 |
+
] # CxN
|
257 |
+
keypoints.append(keypoints_xy)
|
258 |
+
scoredispersitys.append(kptscore) # for jit.script compatability
|
259 |
+
kptscores.append(kptscore)
|
260 |
+
|
261 |
+
return keypoints, scoredispersitys, kptscores
|
262 |
+
|
263 |
+
|
264 |
+
class InputPadder(object):
|
265 |
+
"""Pads images such that dimensions are divisible by 8"""
|
266 |
+
|
267 |
+
def __init__(self, h: int, w: int, divis_by: int = 8):
|
268 |
+
self.ht = h
|
269 |
+
self.wd = w
|
270 |
+
pad_ht = (((self.ht // divis_by) + 1) * divis_by - self.ht) % divis_by
|
271 |
+
pad_wd = (((self.wd // divis_by) + 1) * divis_by - self.wd) % divis_by
|
272 |
+
self._pad = [
|
273 |
+
pad_wd // 2,
|
274 |
+
pad_wd - pad_wd // 2,
|
275 |
+
pad_ht // 2,
|
276 |
+
pad_ht - pad_ht // 2,
|
277 |
+
]
|
278 |
+
|
279 |
+
def pad(self, x: torch.Tensor):
|
280 |
+
assert x.ndim == 4
|
281 |
+
return F.pad(x, self._pad, mode="replicate")
|
282 |
+
|
283 |
+
def unpad(self, x: torch.Tensor):
|
284 |
+
assert x.ndim == 4
|
285 |
+
ht = x.shape[-2]
|
286 |
+
wd = x.shape[-1]
|
287 |
+
c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
|
288 |
+
return x[..., c[0] : c[1], c[2] : c[3]]
|
289 |
+
|
290 |
+
|
291 |
+
class DeformableConv2d(nn.Module):
|
292 |
+
def __init__(
|
293 |
+
self,
|
294 |
+
in_channels,
|
295 |
+
out_channels,
|
296 |
+
kernel_size=3,
|
297 |
+
stride=1,
|
298 |
+
padding=1,
|
299 |
+
bias=False,
|
300 |
+
mask=False,
|
301 |
+
):
|
302 |
+
super(DeformableConv2d, self).__init__()
|
303 |
+
|
304 |
+
self.padding = padding
|
305 |
+
self.mask = mask
|
306 |
+
|
307 |
+
self.channel_num = (
|
308 |
+
3 * kernel_size * kernel_size if mask else 2 * kernel_size * kernel_size
|
309 |
+
)
|
310 |
+
self.offset_conv = nn.Conv2d(
|
311 |
+
in_channels,
|
312 |
+
self.channel_num,
|
313 |
+
kernel_size=kernel_size,
|
314 |
+
stride=stride,
|
315 |
+
padding=self.padding,
|
316 |
+
bias=True,
|
317 |
+
)
|
318 |
+
|
319 |
+
self.regular_conv = nn.Conv2d(
|
320 |
+
in_channels=in_channels,
|
321 |
+
out_channels=out_channels,
|
322 |
+
kernel_size=kernel_size,
|
323 |
+
stride=stride,
|
324 |
+
padding=self.padding,
|
325 |
+
bias=bias,
|
326 |
+
)
|
327 |
+
|
328 |
+
def forward(self, x):
|
329 |
+
h, w = x.shape[2:]
|
330 |
+
max_offset = max(h, w) / 4.0
|
331 |
+
|
332 |
+
out = self.offset_conv(x)
|
333 |
+
if self.mask:
|
334 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
335 |
+
offset = torch.cat((o1, o2), dim=1)
|
336 |
+
mask = torch.sigmoid(mask)
|
337 |
+
else:
|
338 |
+
offset = out
|
339 |
+
mask = None
|
340 |
+
offset = offset.clamp(-max_offset, max_offset)
|
341 |
+
x = torchvision.ops.deform_conv2d(
|
342 |
+
input=x,
|
343 |
+
offset=offset,
|
344 |
+
weight=self.regular_conv.weight,
|
345 |
+
bias=self.regular_conv.bias,
|
346 |
+
padding=self.padding,
|
347 |
+
mask=mask,
|
348 |
+
)
|
349 |
+
return x
|
350 |
+
|
351 |
+
|
352 |
+
def get_conv(
|
353 |
+
inplanes,
|
354 |
+
planes,
|
355 |
+
kernel_size=3,
|
356 |
+
stride=1,
|
357 |
+
padding=1,
|
358 |
+
bias=False,
|
359 |
+
conv_type="conv",
|
360 |
+
mask=False,
|
361 |
+
):
|
362 |
+
if conv_type == "conv":
|
363 |
+
conv = nn.Conv2d(
|
364 |
+
inplanes,
|
365 |
+
planes,
|
366 |
+
kernel_size=kernel_size,
|
367 |
+
stride=stride,
|
368 |
+
padding=padding,
|
369 |
+
bias=bias,
|
370 |
+
)
|
371 |
+
elif conv_type == "dcn":
|
372 |
+
conv = DeformableConv2d(
|
373 |
+
inplanes,
|
374 |
+
planes,
|
375 |
+
kernel_size=kernel_size,
|
376 |
+
stride=stride,
|
377 |
+
padding=_pair(padding),
|
378 |
+
bias=bias,
|
379 |
+
mask=mask,
|
380 |
+
)
|
381 |
+
else:
|
382 |
+
raise TypeError
|
383 |
+
return conv
|
384 |
+
|
385 |
+
|
386 |
+
class ConvBlock(nn.Module):
|
387 |
+
def __init__(
|
388 |
+
self,
|
389 |
+
in_channels,
|
390 |
+
out_channels,
|
391 |
+
gate: Optional[Callable[..., nn.Module]] = None,
|
392 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
393 |
+
conv_type: str = "conv",
|
394 |
+
mask: bool = False,
|
395 |
+
):
|
396 |
+
super().__init__()
|
397 |
+
if gate is None:
|
398 |
+
self.gate = nn.ReLU(inplace=True)
|
399 |
+
else:
|
400 |
+
self.gate = gate
|
401 |
+
if norm_layer is None:
|
402 |
+
norm_layer = nn.BatchNorm2d
|
403 |
+
self.conv1 = get_conv(
|
404 |
+
in_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
|
405 |
+
)
|
406 |
+
self.bn1 = norm_layer(out_channels)
|
407 |
+
self.conv2 = get_conv(
|
408 |
+
out_channels, out_channels, kernel_size=3, conv_type=conv_type, mask=mask
|
409 |
+
)
|
410 |
+
self.bn2 = norm_layer(out_channels)
|
411 |
+
|
412 |
+
def forward(self, x):
|
413 |
+
x = self.gate(self.bn1(self.conv1(x))) # B x in_channels x H x W
|
414 |
+
x = self.gate(self.bn2(self.conv2(x))) # B x out_channels x H x W
|
415 |
+
return x
|
416 |
+
|
417 |
+
|
418 |
+
# modified based on torchvision\models\resnet.py#27->BasicBlock
|
419 |
+
class ResBlock(nn.Module):
|
420 |
+
expansion: int = 1
|
421 |
+
|
422 |
+
def __init__(
|
423 |
+
self,
|
424 |
+
inplanes: int,
|
425 |
+
planes: int,
|
426 |
+
stride: int = 1,
|
427 |
+
downsample: Optional[nn.Module] = None,
|
428 |
+
groups: int = 1,
|
429 |
+
base_width: int = 64,
|
430 |
+
dilation: int = 1,
|
431 |
+
gate: Optional[Callable[..., nn.Module]] = None,
|
432 |
+
norm_layer: Optional[Callable[..., nn.Module]] = None,
|
433 |
+
conv_type: str = "conv",
|
434 |
+
mask: bool = False,
|
435 |
+
) -> None:
|
436 |
+
super(ResBlock, self).__init__()
|
437 |
+
if gate is None:
|
438 |
+
self.gate = nn.ReLU(inplace=True)
|
439 |
+
else:
|
440 |
+
self.gate = gate
|
441 |
+
if norm_layer is None:
|
442 |
+
norm_layer = nn.BatchNorm2d
|
443 |
+
if groups != 1 or base_width != 64:
|
444 |
+
raise ValueError("ResBlock only supports groups=1 and base_width=64")
|
445 |
+
if dilation > 1:
|
446 |
+
raise NotImplementedError("Dilation > 1 not supported in ResBlock")
|
447 |
+
# Both self.conv1 and self.downsample layers
|
448 |
+
# downsample the input when stride != 1
|
449 |
+
self.conv1 = get_conv(
|
450 |
+
inplanes, planes, kernel_size=3, conv_type=conv_type, mask=mask
|
451 |
+
)
|
452 |
+
self.bn1 = norm_layer(planes)
|
453 |
+
self.conv2 = get_conv(
|
454 |
+
planes, planes, kernel_size=3, conv_type=conv_type, mask=mask
|
455 |
+
)
|
456 |
+
self.bn2 = norm_layer(planes)
|
457 |
+
self.downsample = downsample
|
458 |
+
self.stride = stride
|
459 |
+
|
460 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
461 |
+
identity = x
|
462 |
+
|
463 |
+
out = self.conv1(x)
|
464 |
+
out = self.bn1(out)
|
465 |
+
out = self.gate(out)
|
466 |
+
|
467 |
+
out = self.conv2(out)
|
468 |
+
out = self.bn2(out)
|
469 |
+
|
470 |
+
if self.downsample is not None:
|
471 |
+
identity = self.downsample(x)
|
472 |
+
|
473 |
+
out += identity
|
474 |
+
out = self.gate(out)
|
475 |
+
|
476 |
+
return out
|
477 |
+
|
478 |
+
|
479 |
+
class SDDH(nn.Module):
|
480 |
+
def __init__(
|
481 |
+
self,
|
482 |
+
dims: int,
|
483 |
+
kernel_size: int = 3,
|
484 |
+
n_pos: int = 8,
|
485 |
+
gate=nn.ReLU(),
|
486 |
+
conv2D=False,
|
487 |
+
mask=False,
|
488 |
+
):
|
489 |
+
super(SDDH, self).__init__()
|
490 |
+
self.kernel_size = kernel_size
|
491 |
+
self.n_pos = n_pos
|
492 |
+
self.conv2D = conv2D
|
493 |
+
self.mask = mask
|
494 |
+
|
495 |
+
self.get_patches_func = get_patches
|
496 |
+
|
497 |
+
# estimate offsets
|
498 |
+
self.channel_num = 3 * n_pos if mask else 2 * n_pos
|
499 |
+
self.offset_conv = nn.Sequential(
|
500 |
+
nn.Conv2d(
|
501 |
+
dims,
|
502 |
+
self.channel_num,
|
503 |
+
kernel_size=kernel_size,
|
504 |
+
stride=1,
|
505 |
+
padding=0,
|
506 |
+
bias=True,
|
507 |
+
),
|
508 |
+
gate,
|
509 |
+
nn.Conv2d(
|
510 |
+
self.channel_num,
|
511 |
+
self.channel_num,
|
512 |
+
kernel_size=1,
|
513 |
+
stride=1,
|
514 |
+
padding=0,
|
515 |
+
bias=True,
|
516 |
+
),
|
517 |
+
)
|
518 |
+
|
519 |
+
# sampled feature conv
|
520 |
+
self.sf_conv = nn.Conv2d(
|
521 |
+
dims, dims, kernel_size=1, stride=1, padding=0, bias=False
|
522 |
+
)
|
523 |
+
|
524 |
+
# convM
|
525 |
+
if not conv2D:
|
526 |
+
# deformable desc weights
|
527 |
+
agg_weights = torch.nn.Parameter(torch.rand(n_pos, dims, dims))
|
528 |
+
self.register_parameter("agg_weights", agg_weights)
|
529 |
+
else:
|
530 |
+
self.convM = nn.Conv2d(
|
531 |
+
dims * n_pos, dims, kernel_size=1, stride=1, padding=0, bias=False
|
532 |
+
)
|
533 |
+
|
534 |
+
def forward(self, x, keypoints):
|
535 |
+
# x: [B,C,H,W]
|
536 |
+
# keypoints: list, [[N_kpts,2], ...] (w,h)
|
537 |
+
b, c, h, w = x.shape
|
538 |
+
wh = torch.tensor([[w - 1, h - 1]], device=x.device)
|
539 |
+
max_offset = max(h, w) / 4.0
|
540 |
+
|
541 |
+
offsets = []
|
542 |
+
descriptors = []
|
543 |
+
# get offsets for each keypoint
|
544 |
+
for ib in range(b):
|
545 |
+
xi, kptsi = x[ib], keypoints[ib]
|
546 |
+
kptsi_wh = (kptsi / 2 + 0.5) * wh
|
547 |
+
N_kpts = len(kptsi)
|
548 |
+
|
549 |
+
if self.kernel_size > 1:
|
550 |
+
patch = self.get_patches_func(
|
551 |
+
xi, kptsi_wh.long(), self.kernel_size
|
552 |
+
) # [N_kpts, C, K, K]
|
553 |
+
else:
|
554 |
+
kptsi_wh_long = kptsi_wh.long()
|
555 |
+
patch = (
|
556 |
+
xi[:, kptsi_wh_long[:, 1], kptsi_wh_long[:, 0]]
|
557 |
+
.permute(1, 0)
|
558 |
+
.reshape(N_kpts, c, 1, 1)
|
559 |
+
)
|
560 |
+
|
561 |
+
offset = self.offset_conv(patch).clamp(
|
562 |
+
-max_offset, max_offset
|
563 |
+
) # [N_kpts, 2*n_pos, 1, 1]
|
564 |
+
if self.mask:
|
565 |
+
offset = (
|
566 |
+
offset[:, :, 0, 0].view(N_kpts, 3, self.n_pos).permute(0, 2, 1)
|
567 |
+
) # [N_kpts, n_pos, 3]
|
568 |
+
offset = offset[:, :, :-1] # [N_kpts, n_pos, 2]
|
569 |
+
mask_weight = torch.sigmoid(offset[:, :, -1]) # [N_kpts, n_pos]
|
570 |
+
else:
|
571 |
+
offset = (
|
572 |
+
offset[:, :, 0, 0].view(N_kpts, 2, self.n_pos).permute(0, 2, 1)
|
573 |
+
) # [N_kpts, n_pos, 2]
|
574 |
+
offsets.append(offset) # for visualization
|
575 |
+
|
576 |
+
# get sample positions
|
577 |
+
pos = kptsi_wh.unsqueeze(1) + offset # [N_kpts, n_pos, 2]
|
578 |
+
pos = 2.0 * pos / wh[None] - 1
|
579 |
+
pos = pos.reshape(1, N_kpts * self.n_pos, 1, 2)
|
580 |
+
|
581 |
+
# sample features
|
582 |
+
features = F.grid_sample(
|
583 |
+
xi.unsqueeze(0), pos, mode="bilinear", align_corners=True
|
584 |
+
) # [1,C,(N_kpts*n_pos),1]
|
585 |
+
features = features.reshape(c, N_kpts, self.n_pos, 1).permute(
|
586 |
+
1, 0, 2, 3
|
587 |
+
) # [N_kpts, C, n_pos, 1]
|
588 |
+
if self.mask:
|
589 |
+
features = torch.einsum("ncpo,np->ncpo", features, mask_weight)
|
590 |
+
|
591 |
+
features = torch.selu_(self.sf_conv(features)).squeeze(
|
592 |
+
-1
|
593 |
+
) # [N_kpts, C, n_pos]
|
594 |
+
# convM
|
595 |
+
if not self.conv2D:
|
596 |
+
descs = torch.einsum(
|
597 |
+
"ncp,pcd->nd", features, self.agg_weights
|
598 |
+
) # [N_kpts, C]
|
599 |
+
else:
|
600 |
+
features = features.reshape(N_kpts, -1)[
|
601 |
+
:, :, None, None
|
602 |
+
] # [N_kpts, C*n_pos, 1, 1]
|
603 |
+
descs = self.convM(features).squeeze() # [N_kpts, C]
|
604 |
+
|
605 |
+
# normalize
|
606 |
+
descs = F.normalize(descs, p=2.0, dim=1)
|
607 |
+
descriptors.append(descs)
|
608 |
+
|
609 |
+
return descriptors, offsets
|
610 |
+
|
611 |
+
|
612 |
+
class ALIKED(Extractor):
|
613 |
+
default_conf = {
|
614 |
+
"model_name": "aliked-n16",
|
615 |
+
"max_num_keypoints": -1,
|
616 |
+
"detection_threshold": 0.2,
|
617 |
+
"nms_radius": 2,
|
618 |
+
}
|
619 |
+
|
620 |
+
checkpoint_url = "https://github.com/Shiaoming/ALIKED/raw/main/models/{}.pth"
|
621 |
+
|
622 |
+
n_limit_max = 20000
|
623 |
+
|
624 |
+
# c1, c2, c3, c4, dim, K, M
|
625 |
+
cfgs = {
|
626 |
+
"aliked-t16": [8, 16, 32, 64, 64, 3, 16],
|
627 |
+
"aliked-n16": [16, 32, 64, 128, 128, 3, 16],
|
628 |
+
"aliked-n16rot": [16, 32, 64, 128, 128, 3, 16],
|
629 |
+
"aliked-n32": [16, 32, 64, 128, 128, 3, 32],
|
630 |
+
}
|
631 |
+
preprocess_conf = {
|
632 |
+
"resize": 1024,
|
633 |
+
}
|
634 |
+
|
635 |
+
required_data_keys = ["image"]
|
636 |
+
|
637 |
+
def __init__(self, **conf):
|
638 |
+
super().__init__(**conf) # Update with default configuration.
|
639 |
+
conf = self.conf
|
640 |
+
c1, c2, c3, c4, dim, K, M = self.cfgs[conf.model_name]
|
641 |
+
conv_types = ["conv", "conv", "dcn", "dcn"]
|
642 |
+
conv2D = False
|
643 |
+
mask = False
|
644 |
+
|
645 |
+
# build model
|
646 |
+
self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
|
647 |
+
self.pool4 = nn.AvgPool2d(kernel_size=4, stride=4)
|
648 |
+
self.norm = nn.BatchNorm2d
|
649 |
+
self.gate = nn.SELU(inplace=True)
|
650 |
+
self.block1 = ConvBlock(3, c1, self.gate, self.norm, conv_type=conv_types[0])
|
651 |
+
self.block2 = self.get_resblock(c1, c2, conv_types[1], mask)
|
652 |
+
self.block3 = self.get_resblock(c2, c3, conv_types[2], mask)
|
653 |
+
self.block4 = self.get_resblock(c3, c4, conv_types[3], mask)
|
654 |
+
|
655 |
+
self.conv1 = resnet.conv1x1(c1, dim // 4)
|
656 |
+
self.conv2 = resnet.conv1x1(c2, dim // 4)
|
657 |
+
self.conv3 = resnet.conv1x1(c3, dim // 4)
|
658 |
+
self.conv4 = resnet.conv1x1(dim, dim // 4)
|
659 |
+
self.upsample2 = nn.Upsample(
|
660 |
+
scale_factor=2, mode="bilinear", align_corners=True
|
661 |
+
)
|
662 |
+
self.upsample4 = nn.Upsample(
|
663 |
+
scale_factor=4, mode="bilinear", align_corners=True
|
664 |
+
)
|
665 |
+
self.upsample8 = nn.Upsample(
|
666 |
+
scale_factor=8, mode="bilinear", align_corners=True
|
667 |
+
)
|
668 |
+
self.upsample32 = nn.Upsample(
|
669 |
+
scale_factor=32, mode="bilinear", align_corners=True
|
670 |
+
)
|
671 |
+
self.score_head = nn.Sequential(
|
672 |
+
resnet.conv1x1(dim, 8),
|
673 |
+
self.gate,
|
674 |
+
resnet.conv3x3(8, 4),
|
675 |
+
self.gate,
|
676 |
+
resnet.conv3x3(4, 4),
|
677 |
+
self.gate,
|
678 |
+
resnet.conv3x3(4, 1),
|
679 |
+
)
|
680 |
+
self.desc_head = SDDH(dim, K, M, gate=self.gate, conv2D=conv2D, mask=mask)
|
681 |
+
self.dkd = DKD(
|
682 |
+
radius=conf.nms_radius,
|
683 |
+
top_k=-1 if conf.detection_threshold > 0 else conf.max_num_keypoints,
|
684 |
+
scores_th=conf.detection_threshold,
|
685 |
+
n_limit=conf.max_num_keypoints
|
686 |
+
if conf.max_num_keypoints > 0
|
687 |
+
else self.n_limit_max,
|
688 |
+
)
|
689 |
+
|
690 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
691 |
+
self.checkpoint_url.format(conf.model_name), map_location="cpu"
|
692 |
+
)
|
693 |
+
self.load_state_dict(state_dict, strict=True)
|
694 |
+
|
695 |
+
def get_resblock(self, c_in, c_out, conv_type, mask):
|
696 |
+
return ResBlock(
|
697 |
+
c_in,
|
698 |
+
c_out,
|
699 |
+
1,
|
700 |
+
nn.Conv2d(c_in, c_out, 1),
|
701 |
+
gate=self.gate,
|
702 |
+
norm_layer=self.norm,
|
703 |
+
conv_type=conv_type,
|
704 |
+
mask=mask,
|
705 |
+
)
|
706 |
+
|
707 |
+
def extract_dense_map(self, image):
|
708 |
+
# Pads images such that dimensions are divisible by
|
709 |
+
div_by = 2**5
|
710 |
+
padder = InputPadder(image.shape[-2], image.shape[-1], div_by)
|
711 |
+
image = padder.pad(image)
|
712 |
+
|
713 |
+
# ================================== feature encoder
|
714 |
+
x1 = self.block1(image) # B x c1 x H x W
|
715 |
+
x2 = self.pool2(x1)
|
716 |
+
x2 = self.block2(x2) # B x c2 x H/2 x W/2
|
717 |
+
x3 = self.pool4(x2)
|
718 |
+
x3 = self.block3(x3) # B x c3 x H/8 x W/8
|
719 |
+
x4 = self.pool4(x3)
|
720 |
+
x4 = self.block4(x4) # B x dim x H/32 x W/32
|
721 |
+
# ================================== feature aggregation
|
722 |
+
x1 = self.gate(self.conv1(x1)) # B x dim//4 x H x W
|
723 |
+
x2 = self.gate(self.conv2(x2)) # B x dim//4 x H//2 x W//2
|
724 |
+
x3 = self.gate(self.conv3(x3)) # B x dim//4 x H//8 x W//8
|
725 |
+
x4 = self.gate(self.conv4(x4)) # B x dim//4 x H//32 x W//32
|
726 |
+
x2_up = self.upsample2(x2) # B x dim//4 x H x W
|
727 |
+
x3_up = self.upsample8(x3) # B x dim//4 x H x W
|
728 |
+
x4_up = self.upsample32(x4) # B x dim//4 x H x W
|
729 |
+
x1234 = torch.cat([x1, x2_up, x3_up, x4_up], dim=1)
|
730 |
+
# ================================== score head
|
731 |
+
score_map = torch.sigmoid(self.score_head(x1234))
|
732 |
+
feature_map = torch.nn.functional.normalize(x1234, p=2, dim=1)
|
733 |
+
|
734 |
+
# Unpads images
|
735 |
+
feature_map = padder.unpad(feature_map)
|
736 |
+
score_map = padder.unpad(score_map)
|
737 |
+
|
738 |
+
return feature_map, score_map
|
739 |
+
|
740 |
+
def forward(self, data: dict) -> dict:
|
741 |
+
image = data["image"]
|
742 |
+
if image.shape[1] == 1:
|
743 |
+
image = grayscale_to_rgb(image)
|
744 |
+
feature_map, score_map = self.extract_dense_map(image)
|
745 |
+
keypoints, kptscores, scoredispersitys = self.dkd(
|
746 |
+
score_map, image_size=data.get("image_size")
|
747 |
+
)
|
748 |
+
descriptors, offsets = self.desc_head(feature_map, keypoints)
|
749 |
+
|
750 |
+
_, _, h, w = image.shape
|
751 |
+
wh = torch.tensor([w - 1, h - 1], device=image.device)
|
752 |
+
# no padding required
|
753 |
+
# we can set detection_threshold=-1 and conf.max_num_keypoints > 0
|
754 |
+
return {
|
755 |
+
"keypoints": wh * (torch.stack(keypoints) + 1) / 2.0, # B x N x 2
|
756 |
+
"descriptors": torch.stack(descriptors), # B x N x D
|
757 |
+
"keypoint_scores": torch.stack(kptscores), # B x N
|
758 |
+
}
|
third_party/LightGlue/lightglue/disk.py
CHANGED
@@ -1,11 +1,10 @@
-import torch
-import torch.nn as nn
 import kornia
-
-from .utils import ImagePreprocessor
+import torch
+
+from .utils import Extractor
 
 
-class DISK(nn.Module):
+class DISK(Extractor):
     default_conf = {
         "weights": "depth",
         "max_num_keypoints": None,
@@ -16,7 +15,6 @@ class DISK(nn.Module):
     }
 
     preprocess_conf = {
-        **ImagePreprocessor.default_conf,
         "resize": 1024,
         "grayscale": False,
     }
@@ -24,9 +22,7 @@ class DISK(nn.Module):
     required_data_keys = ["image"]
 
     def __init__(self, **conf) -> None:
-        super().__init__()
-        self.conf = {**self.default_conf, **conf}
-        self.conf = SimpleNamespace(**self.conf)
+        super().__init__(**conf)  # Update with default configuration.
         self.model = kornia.feature.DISK.from_pretrained(self.conf.weights)
 
     def forward(self, data: dict) -> dict:
@@ -34,6 +30,8 @@ class DISK(nn.Module):
         for key in self.required_data_keys:
             assert key in data, f"Missing key {key} in data"
         image = data["image"]
+        if image.shape[1] == 1:
+            image = kornia.color.grayscale_to_rgb(image)
         features = self.model(
             image,
             n=self.conf.max_num_keypoints,
@@ -51,19 +49,7 @@ class DISK(nn.Module):
         descriptors = torch.stack(descriptors, 0)
 
         return {
-            "keypoints": keypoints.to(image),
-            "keypoint_scores": scores.to(image),
-            "descriptors": descriptors.to(image),
+            "keypoints": keypoints.to(image).contiguous(),
+            "keypoint_scores": scores.to(image).contiguous(),
+            "descriptors": descriptors.to(image).contiguous(),
         }
-
-    def extract(self, img: torch.Tensor, **conf) -> dict:
-        """Perform extraction with online resizing"""
-        if img.dim() == 3:
-            img = img[None]  # add batch dim
-        assert img.dim() == 4 and img.shape[0] == 1
-        shape = img.shape[-2:][::-1]
-        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
-        feats = self.forward({"image": img})
-        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
-        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
-        return feats
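One practical effect of the added `grayscale_to_rgb` branch: DISK can now be fed single-channel images directly. A small sketch (the random tensor only illustrates the expected input shape; real use would go through `load_image` and `extract` as in the README):

```python
import torch
from lightglue import DISK

extractor = DISK(max_num_keypoints=1024).eval()
gray = torch.rand(1, 1, 480, 640)      # B x 1 x H x W, single channel
feats = extractor({"image": gray})     # converted to 3 channels internally
print(feats["keypoints"].shape)        # B x N x 2
```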
third_party/LightGlue/lightglue/dog_hardnet.py
ADDED
@@ -0,0 +1,41 @@
+import torch
+from kornia.color import rgb_to_grayscale
+from kornia.feature import HardNet, LAFDescriptor, laf_from_center_scale_ori
+
+from .sift import SIFT
+
+
+class DoGHardNet(SIFT):
+    required_data_keys = ["image"]
+
+    def __init__(self, **conf):
+        super().__init__(**conf)
+        self.laf_desc = LAFDescriptor(HardNet(True)).eval()
+
+    def forward(self, data: dict) -> dict:
+        image = data["image"]
+        if image.shape[1] == 3:
+            image = rgb_to_grayscale(image)
+        device = image.device
+        self.laf_desc = self.laf_desc.to(device)
+        self.laf_desc.descriptor = self.laf_desc.descriptor.eval()
+        pred = []
+        if "image_size" in data.keys():
+            im_size = data.get("image_size").long()
+        else:
+            im_size = None
+        for k in range(len(image)):
+            img = image[k]
+            if im_size is not None:
+                w, h = data["image_size"][k]
+                img = img[:, : h.to(torch.int32), : w.to(torch.int32)]
+            p = self.extract_single_image(img)
+            lafs = laf_from_center_scale_ori(
+                p["keypoints"].reshape(1, -1, 2),
+                6.0 * p["scales"].reshape(1, -1, 1, 1),
+                torch.rad2deg(p["oris"]).reshape(1, -1, 1),
+            ).to(device)
+            p["descriptors"] = self.laf_desc(img[None], lafs).reshape(-1, 128)
+            pred.append(p)
+        pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
+        return pred
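Note: DoGHardNet reuses the SIFT detector (DoG keypoints with scales and orientations) and only swaps the descriptor, re-describing each keypoint with HardNet on a local affine frame built by laf_from_center_scale_ori; the 6.0 factor converts the SIFT scale into the LAF patch scale. A hedged usage sketch, assuming the updated lightglue/__init__.py re-exports the new classes:

    from lightglue import DoGHardNet, LightGlue

    extractor = DoGHardNet(max_num_keypoints=2048).eval()
    matcher = LightGlue(features="doghardnet").eval()  # "doghardnet" entry is added below in lightglue.py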
third_party/LightGlue/lightglue/lightglue.py
CHANGED
@@ -1,11 +1,12 @@
+import warnings
 from pathlib import Path
 from types import SimpleNamespace
+from typing import Callable, List, Optional, Tuple
+
 import numpy as np
 import torch
-from torch import nn
 import torch.nn.functional as F
+from torch import nn
 
 try:
     from flash_attn.modules.mha import FlashCrossAttention
@@ -21,15 +22,32 @@ torch.backends.cudnn.deterministic = True
 
 @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
+def normalize_keypoints(
+    kpts: torch.Tensor, size: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    if size is None:
+        size = 1 + kpts.max(-2).values - kpts.min(-2).values
+    elif not isinstance(size, torch.Tensor):
+        size = torch.tensor(size, device=kpts.device, dtype=kpts.dtype)
+    size = size.to(kpts)
+    shift = size / 2
+    scale = size.max(-1).values / 2
+    kpts = (kpts - shift[..., None, :]) / scale[..., None, None]
     return kpts
 
 
+def pad_to_length(x: torch.Tensor, length: int) -> Tuple[torch.Tensor]:
+    if length <= x.shape[-2]:
+        return x, torch.ones_like(x[..., :1], dtype=torch.bool)
+    pad = torch.ones(
+        *x.shape[:-2], length - x.shape[-2], x.shape[-1], device=x.device, dtype=x.dtype
+    )
+    y = torch.cat([x, pad], dim=-2)
+    mask = torch.zeros(*y.shape[:-1], 1, dtype=torch.bool, device=x.device)
+    mask[..., : x.shape[-2], :] = True
+    return y, mask
+
+
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
     x = x.unflatten(-1, (-1, 2))
     x1, x2 = x.unbind(dim=-1)
@@ -64,8 +82,8 @@ class TokenConfidence(nn.Module):
     def forward(self, desc0: torch.Tensor, desc1: torch.Tensor):
         """get confidence tokens"""
         return (
+            self.token(desc0.detach()).squeeze(-1),
+            self.token(desc1.detach()).squeeze(-1),
         )
 
@@ -79,29 +97,40 @@ class Attention(nn.Module):
             stacklevel=2,
         )
         self.enable_flash = allow_flash and FLASH_AVAILABLE
+        self.has_sdp = hasattr(F, "scaled_dot_product_attention")
         if allow_flash and FlashCrossAttention:
             self.flash_ = FlashCrossAttention()
+        if self.has_sdp:
+            torch.backends.cuda.enable_flash_sdp(allow_flash)
 
-    def forward(self, q, k, v) -> torch.Tensor:
+    def forward(self, q, k, v, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if q.shape[-2] == 0 or k.shape[-2] == 0:
+            return q.new_zeros((*q.shape[:-1], v.shape[-1]))
         if self.enable_flash and q.device.type == "cuda":
+            # use torch 2.0 scaled_dot_product_attention with flash
+            if self.has_sdp:
+                args = [x.half().contiguous() for x in [q, k, v]]
+                v = F.scaled_dot_product_attention(*args, attn_mask=mask).to(q.dtype)
+                return v if mask is None else v.nan_to_num()
+            else:
+                assert mask is None
+                q, k, v = [x.transpose(-2, -3).contiguous() for x in [q, k, v]]
                 m = self.flash_(q.half(), torch.stack([k, v], 2).half())
                 return m.transpose(-2, -3).to(q.dtype).clone()
+        elif self.has_sdp:
             args = [x.contiguous() for x in [q, k, v]]
+            v = F.scaled_dot_product_attention(*args, attn_mask=mask)
+            return v if mask is None else v.nan_to_num()
         else:
             s = q.shape[-1] ** -0.5
+            sim = torch.einsum("...id,...jd->...ij", q, k) * s
+            if mask is not None:
+                sim.masked_fill(~mask, -float("inf"))
+            attn = F.softmax(sim, -1)
             return torch.einsum("...ij,...jd->...id", attn, v)
 
 
-class Transformer(nn.Module):
+class SelfBlock(nn.Module):
     def __init__(
         self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
     ) -> None:
@@ -120,22 +149,23 @@ class Transformer(nn.Module):
             nn.Linear(2 * embed_dim, embed_dim),
         )
 
+    def forward(
+        self,
+        x: torch.Tensor,
+        encoding: torch.Tensor,
+        mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         qkv = self.Wqkv(x)
         qkv = qkv.unflatten(-1, (self.num_heads, -1, 3)).transpose(1, 2)
         q, k, v = qkv[..., 0], qkv[..., 1], qkv[..., 2]
+        q = apply_cached_rotary_emb(encoding, q)
+        k = apply_cached_rotary_emb(encoding, k)
+        context = self.inner_attn(q, k, v, mask=mask)
         message = self.out_proj(context.transpose(1, 2).flatten(start_dim=-2))
         return x + self.ffn(torch.cat([x, message], -1))
 
-    def forward(self, x0, x1, encoding0=None, encoding1=None):
-        return self._forward(x0, encoding0), self._forward(x1, encoding1)
 
-class CrossTransformer(nn.Module):
+class CrossBlock(nn.Module):
     def __init__(
         self, embed_dim: int, num_heads: int, flash: bool = False, bias: bool = True
     ) -> None:
@@ -153,7 +183,6 @@ class CrossTransformer(nn.Module):
             nn.GELU(),
             nn.Linear(2 * embed_dim, embed_dim),
         )
-
         if flash and FLASH_AVAILABLE:
             self.flash = Attention(True)
         else:
@@ -162,23 +191,31 @@ class CrossTransformer(nn.Module):
     def map_(self, func: Callable, x0: torch.Tensor, x1: torch.Tensor):
         return func(x0), func(x1)
 
+    def forward(
+        self, x0: torch.Tensor, x1: torch.Tensor, mask: Optional[torch.Tensor] = None
+    ) -> List[torch.Tensor]:
         qk0, qk1 = self.map_(self.to_qk, x0, x1)
         v0, v1 = self.map_(self.to_v, x0, x1)
         qk0, qk1, v0, v1 = map(
             lambda t: t.unflatten(-1, (self.heads, -1)).transpose(1, 2),
             (qk0, qk1, v0, v1),
         )
-        if self.flash is not None:
-            m0 = self.flash(qk0, qk1, v1)
+        if self.flash is not None and qk0.device.type == "cuda":
+            m0 = self.flash(qk0, qk1, v1, mask)
+            m1 = self.flash(
+                qk1, qk0, v0, mask.transpose(-1, -2) if mask is not None else None
+            )
         else:
             qk0, qk1 = qk0 * self.scale**0.5, qk1 * self.scale**0.5
+            sim = torch.einsum("bhid, bhjd -> bhij", qk0, qk1)
+            if mask is not None:
+                sim = sim.masked_fill(~mask, -float("inf"))
             attn01 = F.softmax(sim, dim=-1)
             attn10 = F.softmax(sim.transpose(-2, -1).contiguous(), dim=-1)
             m0 = torch.einsum("bhij, bhjd -> bhid", attn01, v1)
             m1 = torch.einsum("bhji, bhjd -> bhid", attn10.transpose(-2, -1), v0)
+            if mask is not None:
+                m0, m1 = m0.nan_to_num(), m1.nan_to_num()
         m0, m1 = self.map_(lambda t: t.transpose(1, 2).flatten(start_dim=-2), m0, m1)
         m0, m1 = self.map_(self.to_out, m0, m1)
         x0 = x0 + self.ffn(torch.cat([x0, m0], -1))
@@ -186,6 +223,38 @@ class CrossTransformer(nn.Module):
         return x0, x1
 
 
+class TransformerLayer(nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.self_attn = SelfBlock(*args, **kwargs)
+        self.cross_attn = CrossBlock(*args, **kwargs)
+
+    def forward(
+        self,
+        desc0,
+        desc1,
+        encoding0,
+        encoding1,
+        mask0: Optional[torch.Tensor] = None,
+        mask1: Optional[torch.Tensor] = None,
+    ):
+        if mask0 is not None and mask1 is not None:
+            return self.masked_forward(desc0, desc1, encoding0, encoding1, mask0, mask1)
+        else:
+            desc0 = self.self_attn(desc0, encoding0)
+            desc1 = self.self_attn(desc1, encoding1)
+            return self.cross_attn(desc0, desc1)
+
+    # This part is compiled and allows padding inputs
+    def masked_forward(self, desc0, desc1, encoding0, encoding1, mask0, mask1):
+        mask = mask0 & mask1.transpose(-1, -2)
+        mask0 = mask0 & mask0.transpose(-1, -2)
+        mask1 = mask1 & mask1.transpose(-1, -2)
+        desc0 = self.self_attn(desc0, encoding0, mask0)
+        desc1 = self.self_attn(desc1, encoding1, mask1)
+        return self.cross_attn(desc0, desc1, mask)
+
+
 def sigmoid_log_double_softmax(
     sim: torch.Tensor, z0: torch.Tensor, z1: torch.Tensor
 ) -> torch.Tensor:
@@ -219,29 +288,26 @@ class MatchAssignment(nn.Module):
         scores = sigmoid_log_double_softmax(sim, z0, z1)
         return scores, sim
 
-    def scores(self, desc0: torch.Tensor, desc1: torch.Tensor):
-        m0 = torch.sigmoid(self.matchability(desc0)).squeeze(-1)
-        m1 = torch.sigmoid(self.matchability(desc1)).squeeze(-1)
-        return m0, m1
+    def get_matchability(self, desc: torch.Tensor):
+        return torch.sigmoid(self.matchability(desc)).squeeze(-1)
 
 
 def filter_matches(scores: torch.Tensor, th: float):
     """obtain matches from a log assignment matrix [Bx M+1 x N+1]"""
     max0, max1 = scores[:, :-1, :-1].max(2), scores[:, :-1, :-1].max(1)
     m0, m1 = max0.indices, max1.indices
+    indices0 = torch.arange(m0.shape[1], device=m0.device)[None]
+    indices1 = torch.arange(m1.shape[1], device=m1.device)[None]
+    mutual0 = indices0 == m1.gather(1, m0)
+    mutual1 = indices1 == m0.gather(1, m1)
     max0_exp = max0.values.exp()
     zero = max0_exp.new_tensor(0)
     mscores0 = torch.where(mutual0, max0_exp, zero)
     mscores1 = torch.where(mutual1, mscores0.gather(1, m1), zero)
-        valid0 = mutual0 & (mscores0 > th)
-    else:
-        valid0 = mutual0
+    valid0 = mutual0 & (mscores0 > th)
     valid1 = mutual1 & valid0.gather(1, m1)
+    m0 = torch.where(valid0, m0, -1)
+    m1 = torch.where(valid1, m1, -1)
     return m0, m1, mscores0, mscores1
 
@@ -250,6 +316,7 @@ class LightGlue(nn.Module):
         "name": "lightglue",  # just for interfacing
         "input_dim": 256,  # input descriptor dimension (autoselected from weights)
         "descriptor_dim": 256,
+        "add_scale_ori": False,
         "n_layers": 9,
         "num_heads": 4,
         "flash": True,  # enable FlashAttention if available.
@@ -260,23 +327,56 @@ class LightGlue(nn.Module):
         "weights": None,
     }
 
+    # Point pruning involves an overhead (gather).
+    # Therefore, we only activate it if there are enough keypoints.
+    pruning_keypoint_thresholds = {
+        "cpu": -1,
+        "mps": -1,
+        "cuda": 1024,
+        "flash": 1536,
+    }
+
     required_data_keys = ["image0", "image1"]
 
     version = "v0.1_arxiv"
     url = "https://github.com/cvg/LightGlue/releases/download/{}/{}_lightglue.pth"
 
     features = {
+        "superpoint": {
+            "weights": "superpoint_lightglue",
+            "input_dim": 256,
+        },
+        "disk": {
+            "weights": "disk_lightglue",
+            "input_dim": 128,
+        },
+        "aliked": {
+            "weights": "aliked_lightglue",
+            "input_dim": 128,
+        },
+        "sift": {
+            "weights": "sift_lightglue",
+            "input_dim": 128,
+            "add_scale_ori": True,
+        },
+        "doghardnet": {
+            "weights": "doghardnet_lightglue",
+            "input_dim": 128,
+            "add_scale_ori": True,
+        },
     }
 
     def __init__(self, features="superpoint", **conf) -> None:
         super().__init__()
-        self.conf = {**self.default_conf, **conf}
+        self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
         if features is not None:
+            if features not in self.features:
+                raise ValueError(
+                    f"Unsupported features: {features} not in "
+                    f"{{{','.join(self.features)}}}"
+                )
+            for k, v in self.features[features].items():
+                setattr(conf, k, v)
 
         if conf.input_dim != conf.descriptor_dim:
             self.input_proj = nn.Linear(conf.input_dim, conf.descriptor_dim, bias=True)
@@ -284,22 +384,30 @@ class LightGlue(nn.Module):
             self.input_proj = nn.Identity()
 
         head_dim = conf.descriptor_dim // conf.num_heads
+        self.posenc = LearnableFourierPositionalEncoding(
+            2 + 2 * self.conf.add_scale_ori, head_dim, head_dim
+        )
 
         h, n, d = conf.num_heads, conf.n_layers, conf.descriptor_dim
-        self.cross_attn = nn.ModuleList(
-            [CrossTransformer(d, h, conf.flash) for _ in range(n)]
-        )
+
+        self.transformers = nn.ModuleList(
+            [TransformerLayer(d, h, conf.flash) for _ in range(n)]
+        )
+
         self.log_assignment = nn.ModuleList([MatchAssignment(d) for _ in range(n)])
         self.token_confidence = nn.ModuleList(
             [TokenConfidence(d) for _ in range(n - 1)]
         )
+        self.register_buffer(
+            "confidence_thresholds",
+            torch.Tensor(
+                [self.confidence_threshold(i) for i in range(self.conf.n_layers)]
+            ),
+        )
 
+        state_dict = None
         if features is not None:
+            fname = f"{conf.weights}_{self.version.replace('.', '-')}.pth"
             state_dict = torch.hub.load_state_dict_from_url(
                 self.url.format(self.version, features), file_name=fname
             )
@@ -308,9 +416,35 @@ class LightGlue(nn.Module):
             path = Path(__file__).parent
             path = path / "weights/{}.pth".format(self.conf.weights)
             state_dict = torch.load(str(path), map_location="cpu")
+
+        if state_dict:
+            # rename old state dict entries
+            for i in range(self.conf.n_layers):
+                pattern = f"self_attn.{i}", f"transformers.{i}.self_attn"
+                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
+                pattern = f"cross_attn.{i}", f"transformers.{i}.cross_attn"
+                state_dict = {k.replace(*pattern): v for k, v in state_dict.items()}
             self.load_state_dict(state_dict, strict=False)
 
+        # static lengths LightGlue is compiled for (only used with torch.compile)
+        self.static_lengths = None
+
+    def compile(
+        self, mode="reduce-overhead", static_lengths=[256, 512, 768, 1024, 1280, 1536]
+    ):
+        if self.conf.width_confidence != -1:
+            warnings.warn(
+                "Point pruning is partially disabled for compiled forward.",
+                stacklevel=2,
+            )
+
+        torch._inductor.cudagraph_mark_step_begin()
+        for i in range(self.conf.n_layers):
+            self.transformers[i].masked_forward = torch.compile(
+                self.transformers[i].masked_forward, mode=mode, fullgraph=True
+            )
+
+        self.static_lengths = static_lengths
 
     def forward(self, data: dict) -> dict:
         """
@@ -326,12 +460,15 @@ class LightGlue(nn.Module):
             descriptors: [B x N x D]
             image: [B x C x H x W] or image_size: [B x 2]
         Output (dict):
-            log_assignment: [B x M+1 x N+1]
             matches0: [B x M]
             matching_scores0: [B x M]
             matches1: [B x N]
             matching_scores1: [B x N]
+            matches: List[[Si x 2]]
+            scores: List[[Si]]
+            stop: int
+            prune0: [B x M]
+            prune1: [B x N]
         """
         with torch.autocast(enabled=self.conf.mp, device_type="cuda"):
             return self._forward(data)
@@ -340,20 +477,23 @@ class LightGlue(nn.Module):
         for key in self.required_data_keys:
             assert key in data, f"Missing key {key} in data"
         data0, data1 = data["image0"], data["image1"]
+        kpts0, kpts1 = data0["keypoints"], data1["keypoints"]
+        b, m, _ = kpts0.shape
+        b, n, _ = kpts1.shape
+        device = kpts0.device
         size0, size1 = data0.get("image_size"), data1.get("image_size")
-        kpts0 = normalize_keypoints(kpts0_, size=size0)
-        kpts1 = normalize_keypoints(kpts1_, size=size1)
+        kpts0 = normalize_keypoints(kpts0, size0).clone()
+        kpts1 = normalize_keypoints(kpts1, size1).clone()
 
+        if self.conf.add_scale_ori:
+            kpts0 = torch.cat(
+                [kpts0] + [data0[k].unsqueeze(-1) for k in ("scales", "oris")], -1
+            )
+            kpts1 = torch.cat(
+                [kpts1] + [data1[k].unsqueeze(-1) for k in ("scales", "oris")], -1
+            )
+        desc0 = data0["descriptors"].detach().contiguous()
+        desc1 = data1["descriptors"].detach().contiguous()
 
         assert desc0.shape[-1] == self.conf.input_dim
         assert desc1.shape[-1] == self.conf.input_dim
@@ -362,109 +502,154 @@ class LightGlue(nn.Module):
             desc0 = desc0.half()
             desc1 = desc1.half()
 
+        mask0, mask1 = None, None
+        c = max(m, n)
+        do_compile = self.static_lengths and c <= max(self.static_lengths)
+        if do_compile:
+            kn = min([k for k in self.static_lengths if k >= c])
+            desc0, mask0 = pad_to_length(desc0, kn)
+            desc1, mask1 = pad_to_length(desc1, kn)
+            kpts0, _ = pad_to_length(kpts0, kn)
+            kpts1, _ = pad_to_length(kpts1, kn)
         desc0 = self.input_proj(desc0)
         desc1 = self.input_proj(desc1)
         # cache positional embeddings
         encoding0 = self.posenc(kpts0)
         encoding1 = self.posenc(kpts1)
 
         # GNN + final_proj + assignment
+        do_early_stop = self.conf.depth_confidence > 0
+        do_point_pruning = self.conf.width_confidence > 0 and not do_compile
+        pruning_th = self.pruning_min_kpts(device)
+        if do_point_pruning:
+            ind0 = torch.arange(0, m, device=device)[None]
+            ind1 = torch.arange(0, n, device=device)[None]
+            # We store the index of the layer at which pruning is detected.
+            prune0 = torch.ones_like(ind0)
+            prune1 = torch.ones_like(ind1)
         token0, token1 = None, None
         for i in range(self.conf.n_layers):
+            if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
+                break
+            desc0, desc1 = self.transformers[i](
+                desc0, desc1, encoding0, encoding1, mask0=mask0, mask1=mask1
+            )
             if i == self.conf.n_layers - 1:
                 continue  # no early stopping or adaptive width at last layer
+
+            if do_early_stop:
                 token0, token1 = self.token_confidence[i](desc0, desc1)
+                if self.check_if_stop(token0[..., :m], token1[..., :n], i, m + n):
                     break
-            if wic > 0:  # point pruning
-                match0, match1 = self.log_assignment[i].scores(desc0, desc1)
-                mask0 = self.get_mask(token0, match0, self.conf_th(i), 1 - wic)
-                mask1 = self.get_mask(token1, match1, self.conf_th(i), 1 - wic)
-                ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
-                desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
-                if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
-                    break
+            if do_point_pruning and desc0.shape[-2] > pruning_th:
+                scores0 = self.log_assignment[i].get_matchability(desc0)
+                prunemask0 = self.get_pruning_mask(token0, scores0, i)
+                keep0 = torch.where(prunemask0)[1]
+                ind0 = ind0.index_select(1, keep0)
+                desc0 = desc0.index_select(1, keep0)
+                encoding0 = encoding0.index_select(-2, keep0)
+                prune0[:, ind0] += 1
+            if do_point_pruning and desc1.shape[-2] > pruning_th:
+                scores1 = self.log_assignment[i].get_matchability(desc1)
+                prunemask1 = self.get_pruning_mask(token1, scores1, i)
+                keep1 = torch.where(prunemask1)[1]
+                ind1 = ind1.index_select(1, keep1)
+                desc1 = desc1.index_select(1, keep1)
+                encoding1 = encoding1.index_select(-2, keep1)
+                prune1[:, ind1] += 1
+
+        if desc0.shape[1] == 0 or desc1.shape[1] == 0:  # no keypoints
+            m0 = desc0.new_full((b, m), -1, dtype=torch.long)
+            m1 = desc1.new_full((b, n), -1, dtype=torch.long)
+            mscores0 = desc0.new_zeros((b, m))
+            mscores1 = desc1.new_zeros((b, n))
+            matches = desc0.new_empty((b, 0, 2), dtype=torch.long)
+            mscores = desc0.new_empty((b, 0))
+            if not do_point_pruning:
+                prune0 = torch.ones_like(mscores0) * self.conf.n_layers
+                prune1 = torch.ones_like(mscores1) * self.conf.n_layers
+            return {
+                "matches0": m0,
+                "matches1": m1,
+                "matching_scores0": mscores0,
+                "matching_scores1": mscores1,
+                "stop": i + 1,
+                "matches": matches,
+                "scores": mscores,
+                "prune0": prune0,
+                "prune1": prune1,
+            }
+
+        desc0, desc1 = desc0[..., :m, :], desc1[..., :n, :]  # remove padding
+        scores, _ = self.log_assignment[i](desc0, desc1)
         m0, m1, mscores0, mscores1 = filter_matches(scores, self.conf.filter_threshold)
         matches, mscores = [], []
         for k in range(b):
             valid = m0[k] > -1
+            m_indices_0 = torch.where(valid)[0]
+            m_indices_1 = m0[k][valid]
+            if do_point_pruning:
+                m_indices_0 = ind0[k, m_indices_0]
+                m_indices_1 = ind1[k, m_indices_1]
+            matches.append(torch.stack([m_indices_0, m_indices_1], -1))
             mscores.append(mscores0[k][valid])
 
+        # TODO: Remove when hloc switches to the compact format.
+        if do_point_pruning:
+            m0_ = torch.full((b, m), -1, device=m0.device, dtype=m0.dtype)
+            m1_ = torch.full((b, n), -1, device=m1.device, dtype=m1.dtype)
+            m0_[:, ind0] = torch.where(m0 == -1, -1, ind1.gather(1, m0.clamp(min=0)))
+            m1_[:, ind1] = torch.where(m1 == -1, -1, ind0.gather(1, m1.clamp(min=0)))
+            mscores0_ = torch.zeros((b, m), device=mscores0.device)
+            mscores1_ = torch.zeros((b, n), device=mscores1.device)
+            mscores0_[:, ind0] = mscores0
+            mscores1_[:, ind1] = mscores1
+            m0, m1, mscores0, mscores1 = m0_, m1_, mscores0_, mscores1_
+        else:
+            prune0 = torch.ones_like(mscores0) * self.conf.n_layers
+            prune1 = torch.ones_like(mscores1) * self.conf.n_layers
+
         return {
-            "log_assignment": scores,
             "matches0": m0,
             "matches1": m1,
             "matching_scores0": mscores0,
             "matching_scores1": mscores1,
             "stop": i + 1,
             "matches": matches,
             "scores": mscores,
+            "prune0": prune0,
+            "prune1": prune1,
         }
 
+    def confidence_threshold(self, layer_index: int) -> float:
         """scaled confidence threshold"""
+        threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.conf.n_layers)
+        return np.clip(threshold, 0, 1)
 
+    def get_pruning_mask(
+        self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int
     ) -> torch.Tensor:
         """mask points which should be removed"""
+        keep = scores > (1 - self.conf.width_confidence)
+        if confidences is not None:  # Low-confidence points are never pruned.
+            keep |= confidences <= self.confidence_thresholds[layer_index]
+        return keep
 
+    def check_if_stop(
         self,
+        confidences0: torch.Tensor,
+        confidences1: torch.Tensor,
+        layer_index: int,
+        num_points: int,
     ) -> torch.Tensor:
         """evaluate stopping condition"""
+        confidences = torch.cat([confidences0, confidences1], -1)
+        threshold = self.confidence_thresholds[layer_index]
+        ratio_confident = 1.0 - (confidences < threshold).float().sum() / num_points
+        return ratio_confident > self.conf.depth_confidence
+
+    def pruning_min_kpts(self, device: torch.device):
+        if self.conf.flash and FLASH_AVAILABLE and device.type == "cuda":
+            return self.pruning_keypoint_thresholds["flash"]
         else:
+            return self.pruning_keypoint_thresholds[device.type]
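Note: the updated matcher adds keypoint scale and orientation to the positional encoding when "add_scale_ori" is set (as for the new SIFT and DoGHardNet heads), and gains an optional torch.compile path with padded static lengths. A hedged end-to-end sketch, assuming the package re-exports LightGlue and SIFT and that the sift_lightglue weights are downloadable as in the upstream release; paths are repo assets used as examples:

    import torch
    from lightglue import LightGlue, SIFT
    from lightglue.utils import load_image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    extractor = SIFT(max_num_keypoints=4096).eval().to(device)
    matcher = LightGlue(features="sift").eval().to(device)
    # matcher.compile()  # optional: compiles masked_forward and pads inputs to static lengths

    feats0 = extractor.extract(load_image("third_party/LightGlue/assets/DSC_0410.JPG").to(device))
    feats1 = extractor.extract(load_image("third_party/LightGlue/assets/DSC_0411.JPG").to(device))
    out = matcher({"image0": feats0, "image1": feats1})
    matches = out["matches"][0]                     # [S x 2] indices into the two keypoint sets
    pts0 = feats0["keypoints"][0][matches[:, 0]]
    pts1 = feats1["keypoints"][0][matches[:, 1]]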
third_party/LightGlue/lightglue/sift.py
ADDED
@@ -0,0 +1,216 @@
+import warnings
+
+import cv2
+import numpy as np
+import torch
+from kornia.color import rgb_to_grayscale
+from packaging import version
+
+try:
+    import pycolmap
+except ImportError:
+    pycolmap = None
+
+from .utils import Extractor
+
+
+def filter_dog_point(points, scales, angles, image_shape, nms_radius, scores=None):
+    h, w = image_shape
+    ij = np.round(points - 0.5).astype(int).T[::-1]
+
+    # Remove duplicate points (identical coordinates).
+    # Pick highest scale or score
+    s = scales if scores is None else scores
+    buffer = np.zeros((h, w))
+    np.maximum.at(buffer, tuple(ij), s)
+    keep = np.where(buffer[tuple(ij)] == s)[0]
+
+    # Pick lowest angle (arbitrary).
+    ij = ij[:, keep]
+    buffer[:] = np.inf
+    o_abs = np.abs(angles[keep])
+    np.minimum.at(buffer, tuple(ij), o_abs)
+    mask = buffer[tuple(ij)] == o_abs
+    ij = ij[:, mask]
+    keep = keep[mask]
+
+    if nms_radius > 0:
+        # Apply NMS on the remaining points
+        buffer[:] = 0
+        buffer[tuple(ij)] = s[keep]  # scores or scale
+
+        local_max = torch.nn.functional.max_pool2d(
+            torch.from_numpy(buffer).unsqueeze(0),
+            kernel_size=nms_radius * 2 + 1,
+            stride=1,
+            padding=nms_radius,
+        ).squeeze(0)
+        is_local_max = buffer == local_max.numpy()
+        keep = keep[is_local_max[tuple(ij)]]
+    return keep
+
+
+def sift_to_rootsift(x: torch.Tensor, eps=1e-6) -> torch.Tensor:
+    x = torch.nn.functional.normalize(x, p=1, dim=-1, eps=eps)
+    x.clip_(min=eps).sqrt_()
+    return torch.nn.functional.normalize(x, p=2, dim=-1, eps=eps)
+
+
+def run_opencv_sift(features: cv2.Feature2D, image: np.ndarray) -> np.ndarray:
+    """
+    Detect keypoints using OpenCV Detector.
+    Optionally, perform description.
+    Args:
+        features: OpenCV based keypoints detector and descriptor
+        image: Grayscale image of uint8 data type
+    Returns:
+        keypoints: 1D array of detected cv2.KeyPoint
+        scores: 1D array of responses
+        descriptors: 1D array of descriptors
+    """
+    detections, descriptors = features.detectAndCompute(image, None)
+    points = np.array([k.pt for k in detections], dtype=np.float32)
+    scores = np.array([k.response for k in detections], dtype=np.float32)
+    scales = np.array([k.size for k in detections], dtype=np.float32)
+    angles = np.deg2rad(np.array([k.angle for k in detections], dtype=np.float32))
+    return points, scores, scales, angles, descriptors
+
+
+class SIFT(Extractor):
+    default_conf = {
+        "rootsift": True,
+        "nms_radius": 0,  # None to disable filtering entirely.
+        "max_num_keypoints": 4096,
+        "backend": "opencv",  # in {opencv, pycolmap, pycolmap_cpu, pycolmap_cuda}
+        "detection_threshold": 0.0066667,  # from COLMAP
+        "edge_threshold": 10,
+        "first_octave": -1,  # only used by pycolmap, the default of COLMAP
+        "num_octaves": 4,
+    }
+
+    preprocess_conf = {
+        "resize": 1024,
+    }
+
+    required_data_keys = ["image"]
+
+    def __init__(self, **conf):
+        super().__init__(**conf)  # Update with default configuration.
+        backend = self.conf.backend
+        if backend.startswith("pycolmap"):
+            if pycolmap is None:
+                raise ImportError(
+                    "Cannot find module pycolmap: install it with pip"
+                    "or use backend=opencv."
+                )
+            options = {
+                "peak_threshold": self.conf.detection_threshold,
+                "edge_threshold": self.conf.edge_threshold,
+                "first_octave": self.conf.first_octave,
+                "num_octaves": self.conf.num_octaves,
+                "normalization": pycolmap.Normalization.L2,  # L1_ROOT is buggy.
+            }
+            device = (
+                "auto" if backend == "pycolmap" else backend.replace("pycolmap_", "")
+            )
+            if (
+                backend == "pycolmap_cpu" or not pycolmap.has_cuda
+            ) and pycolmap.__version__ < "0.5.0":
+                warnings.warn(
+                    "The pycolmap CPU SIFT is buggy in version < 0.5.0, "
+                    "consider upgrading pycolmap or use the CUDA version.",
+                    stacklevel=1,
+                )
+            else:
+                options["max_num_features"] = self.conf.max_num_keypoints
+            self.sift = pycolmap.Sift(options=options, device=device)
+        elif backend == "opencv":
+            self.sift = cv2.SIFT_create(
+                contrastThreshold=self.conf.detection_threshold,
+                nfeatures=self.conf.max_num_keypoints,
+                edgeThreshold=self.conf.edge_threshold,
+                nOctaveLayers=self.conf.num_octaves,
+            )
+        else:
+            backends = {"opencv", "pycolmap", "pycolmap_cpu", "pycolmap_cuda"}
+            raise ValueError(
+                f"Unknown backend: {backend} not in " f"{{{','.join(backends)}}}."
+            )
+
+    def extract_single_image(self, image: torch.Tensor):
+        image_np = image.cpu().numpy().squeeze(0)
+
+        if self.conf.backend.startswith("pycolmap"):
+            if version.parse(pycolmap.__version__) >= version.parse("0.5.0"):
+                detections, descriptors = self.sift.extract(image_np)
+                scores = None  # Scores are not exposed by COLMAP anymore.
+            else:
+                detections, scores, descriptors = self.sift.extract(image_np)
+            keypoints = detections[:, :2]  # Keep only (x, y).
+            scales, angles = detections[:, -2:].T
+            if scores is not None and (
+                self.conf.backend == "pycolmap_cpu" or not pycolmap.has_cuda
+            ):
+                # Set the scores as a combination of abs. response and scale.
+                scores = np.abs(scores) * scales
+        elif self.conf.backend == "opencv":
+            # TODO: Check if opencv keypoints are already in corner convention
+            keypoints, scores, scales, angles, descriptors = run_opencv_sift(
+                self.sift, (image_np * 255.0).astype(np.uint8)
+            )
+        pred = {
+            "keypoints": keypoints,
+            "scales": scales,
+            "oris": angles,
+            "descriptors": descriptors,
+        }
+        if scores is not None:
+            pred["keypoint_scores"] = scores
+
+        # sometimes pycolmap returns points outside the image. We remove them
+        if self.conf.backend.startswith("pycolmap"):
+            is_inside = (
+                pred["keypoints"] + 0.5 < np.array([image_np.shape[-2:][::-1]])
+            ).all(-1)
+            pred = {k: v[is_inside] for k, v in pred.items()}
+
+        if self.conf.nms_radius is not None:
+            keep = filter_dog_point(
+                pred["keypoints"],
+                pred["scales"],
+                pred["oris"],
+                image_np.shape,
+                self.conf.nms_radius,
+                scores=pred.get("keypoint_scores"),
+            )
+            pred = {k: v[keep] for k, v in pred.items()}
+
+        pred = {k: torch.from_numpy(v) for k, v in pred.items()}
+        if scores is not None:
+            # Keep the k keypoints with highest score
+            num_points = self.conf.max_num_keypoints
+            if num_points is not None and len(pred["keypoints"]) > num_points:
+                indices = torch.topk(pred["keypoint_scores"], num_points).indices
+                pred = {k: v[indices] for k, v in pred.items()}
+
+        return pred
+
+    def forward(self, data: dict) -> dict:
+        image = data["image"]
+        if image.shape[1] == 3:
+            image = rgb_to_grayscale(image)
+        device = image.device
+        image = image.cpu()
+        pred = []
+        for k in range(len(image)):
+            img = image[k]
+            if "image_size" in data.keys():
+                # avoid extracting points in padded areas
+                w, h = data["image_size"][k]
+                img = img[:, :h, :w]
+            p = self.extract_single_image(img)
+            pred.append(p)
+        pred = {k: torch.stack([p[k] for p in pred], 0).to(device) for k in pred[0]}
+        if self.conf.rootsift:
+            pred["descriptors"] = sift_to_rootsift(pred["descriptors"])
+        return pred
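Note: this extractor wraps either OpenCV SIFT or pycolmap SIFT; descriptors are RootSIFT-normalized by default and filter_dog_point removes the duplicate orientations that DoG detectors emit at identical locations. A minimal hedged sketch with the default OpenCV backend (switch to backend="pycolmap_cuda" only if a CUDA-enabled pycolmap is installed; the image path is a repo asset used as an example):

    from lightglue.sift import SIFT
    from lightglue.utils import load_image

    sift = SIFT(backend="opencv", max_num_keypoints=2048, nms_radius=0).eval()
    feats = sift.extract(load_image("third_party/LightGlue/assets/DSC_0411.JPG"))
    # keypoints: [1, N, 2], scales/oris: [1, N], descriptors: [1, N, 128] (RootSIFT)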
third_party/LightGlue/lightglue/superpoint.py
CHANGED
@@ -43,8 +43,10 @@
 # Adapted by Remi Pautrat, Philipp Lindenberger
 
 import torch
+from kornia.color import rgb_to_grayscale
 from torch import nn
+
+from .utils import Extractor
 
 
 def simple_nms(scores, nms_radius: int):
@@ -77,7 +79,9 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
     """Interpolate descriptors at keypoint locations"""
     b, c, h, w = descriptors.shape
     keypoints = keypoints - s / 2 + 0.5
+    keypoints /= torch.tensor(
+        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+    ).to(
         keypoints
     )[None]
     keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
@@ -91,7 +95,7 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
     return descriptors
 
 
-class SuperPoint(nn.Module):
+class SuperPoint(Extractor):
     """SuperPoint Convolutional Detector and Descriptor
 
     SuperPoint: Self-Supervised Interest Point Detection and
@@ -109,17 +113,13 @@ class SuperPoint(nn.Module):
     }
 
     preprocess_conf = {
-        **ImagePreprocessor.default_conf,
         "resize": 1024,
-        "grayscale": True,
     }
 
     required_data_keys = ["image"]
 
     def __init__(self, **conf):
-        super().__init__()
-        self.conf = {**self.default_conf, **conf}
-
+        super().__init__(**conf)  # Update with default configuration.
         self.relu = nn.ReLU(inplace=True)
         self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
         c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256
@@ -138,26 +138,23 @@ class SuperPoint(nn.Module):
 
         self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
         self.convDb = nn.Conv2d(
+            c5, self.conf.descriptor_dim, kernel_size=1, stride=1, padding=0
         )
 
-        url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth"
+        url = "https://github.com/cvg/LightGlue/releases/download/v0.1_arxiv/superpoint_v1.pth"  # noqa
         self.load_state_dict(torch.hub.load_state_dict_from_url(url))
 
-        if mk is not None and mk <= 0:
+        if self.conf.max_num_keypoints is not None and self.conf.max_num_keypoints <= 0:
             raise ValueError("max_num_keypoints must be positive or None")
 
-        print("Loaded SuperPoint model")
-
     def forward(self, data: dict) -> dict:
         """Compute keypoints, scores, descriptors for image"""
         for key in self.required_data_keys:
             assert key in data, f"Missing key {key} in data"
         image = data["image"]
+        if image.shape[1] == 3:
+            image = rgb_to_grayscale(image)
+
         # Shared Encoder
         x = self.relu(self.conv1a(image))
         x = self.relu(self.conv1b(x))
@@ -178,18 +175,18 @@ class SuperPoint(nn.Module):
         b, _, h, w = scores.shape
         scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
         scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h * 8, w * 8)
+        scores = simple_nms(scores, self.conf.nms_radius)
 
         # Discard keypoints near the image borders
+        if self.conf.remove_borders:
+            pad = self.conf.remove_borders
             scores[:, :pad] = -1
             scores[:, :, :pad] = -1
             scores[:, -pad:] = -1
             scores[:, :, -pad:] = -1
 
         # Extract keypoints
+        best_kp = torch.where(scores > self.conf.detection_threshold)
         scores = scores[best_kp]
 
         # Separate into batches
@@ -199,11 +196,11 @@ class SuperPoint(nn.Module):
         scores = [scores[best_kp[0] == i] for i in range(b)]
 
         # Keep the k keypoints with highest score
+        if self.conf.max_num_keypoints is not None:
             keypoints, scores = list(
                 zip(
                     *[
+                        top_k_keypoints(k, s, self.conf.max_num_keypoints)
                         for k, s in zip(keypoints, scores)
                     ]
                 )
@@ -226,17 +223,5 @@ class SuperPoint(nn.Module):
         return {
             "keypoints": torch.stack(keypoints, 0),
             "keypoint_scores": torch.stack(scores, 0),
-            "descriptors": torch.stack(descriptors, 0).transpose(-1, -2),
+            "descriptors": torch.stack(descriptors, 0).transpose(-1, -2).contiguous(),
         }
-
-    def extract(self, img: torch.Tensor, **conf) -> dict:
-        """Perform extraction with online resizing"""
-        if img.dim() == 3:
-            img = img[None]  # add batch dim
-        assert img.dim() == 4 and img.shape[0] == 1
-        shape = img.shape[-2:][::-1]
-        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
-        feats = self.forward({"image": img})
-        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
-        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
-        return feats
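Note: SuperPoint now reads its options through the shared SimpleNamespace conf, so keyword overrides are spelled the same way as for DISK and SIFT. A hedged sketch (the option names come from SuperPoint's default_conf; grayscale conversion happens inside forward(), resizing inside extract()):

    from lightglue import SuperPoint

    sp = SuperPoint(max_num_keypoints=1024, detection_threshold=0.0005, nms_radius=4).eval()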
third_party/LightGlue/lightglue/utils.py
CHANGED
@@ -1,11 +1,12 @@
+import collections.abc as collections
 from pathlib import Path
+from types import SimpleNamespace
+from typing import Callable, List, Optional, Tuple, Union
+
 import cv2
+import kornia
 import numpy as np
-import collections.abc as collections
-from types import SimpleNamespace
+import torch
 
 
 class ImagePreprocessor:
@@ -15,7 +16,6 @@ class ImagePreprocessor:
         "interpolation": "bilinear",
         "align_corners": None,
         "antialias": True,
-        "grayscale": False,  # convert rgb to grayscale
     }
 
     def __init__(self, **conf) -> None:
@@ -35,10 +35,6 @@ class ImagePreprocessor:
             align_corners=self.conf.align_corners,
         )
         scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
-        if self.conf.grayscale and img.shape[-3] == 3:
-            img = kornia.color.rgb_to_grayscale(img)
-        elif not self.conf.grayscale and img.shape[-3] == 1:
-            img = kornia.color.grayscale_to_rgb(img)
         return img, scale
 
@@ -132,6 +128,25 @@ def load_image(path: Path, resize: int = None, **kwargs) -> torch.Tensor:
     return numpy_image_to_torch(image)
 
 
+class Extractor(torch.nn.Module):
+    def __init__(self, **conf):
+        super().__init__()
+        self.conf = SimpleNamespace(**{**self.default_conf, **conf})
+
+    @torch.no_grad()
+    def extract(self, img: torch.Tensor, **conf) -> dict:
+        """Perform extraction with online resizing"""
+        if img.dim() == 3:
+            img = img[None]  # add batch dim
+        assert img.dim() == 4 and img.shape[0] == 1
+        shape = img.shape[-2:][::-1]
+        img, scales = ImagePreprocessor(**{**self.preprocess_conf, **conf})(img)
+        feats = self.forward({"image": img})
+        feats["image_size"] = torch.tensor(shape)[None].to(img).float()
+        feats["keypoints"] = (feats["keypoints"] + 0.5) / scales[None] - 0.5
+        return feats
+
+
 def match_pair(
     extractor,
     matcher,
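Note: the Extractor base is what the per-model extract() methods were folded into; any detector that returns "keypoints" and "descriptors" from forward() gets the resize-aware helper for free. A purely illustrative sketch (the RandomKeypoints class is hypothetical, not part of the repo):

    import torch
    from lightglue.utils import Extractor

    class RandomKeypoints(Extractor):
        """Toy extractor: uniformly random keypoints with random descriptors."""
        default_conf = {"max_num_keypoints": 512, "descriptor_dim": 128}
        preprocess_conf = {"resize": 1024}
        required_data_keys = ["image"]

        def forward(self, data: dict) -> dict:
            b, _, h, w = data["image"].shape
            n, d = self.conf.max_num_keypoints, self.conf.descriptor_dim
            kpts = torch.rand(b, n, 2) * torch.tensor([w, h], dtype=torch.float32)
            return {
                "keypoints": kpts,
                "keypoint_scores": torch.rand(b, n),
                "descriptors": torch.rand(b, n, d),
            }

    feats = RandomKeypoints().extract(torch.rand(3, 480, 640))  # image_size is filled in by extract()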
third_party/LightGlue/lightglue/viz2d.py
CHANGED
@@ -6,8 +6,8 @@
 """
 
 import matplotlib
-import matplotlib.pyplot as plt
 import matplotlib.patheffects as path_effects
+import matplotlib.pyplot as plt
 import numpy as np
 import torch
third_party/LightGlue/pyproject.toml
ADDED
@@ -0,0 +1,30 @@
+[project]
+name = "lightglue"
+description = "LightGlue: Local Feature Matching at Light Speed"
+version = "0.0"
+authors = [
+    {name = "Philipp Lindenberger"},
+    {name = "Paul-Edouard Sarlin"},
+]
+readme = "README.md"
+requires-python = ">=3.6"
+license = {file = "LICENSE"}
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "License :: OSI Approved :: Apache Software License",
+    "Operating System :: OS Independent",
+]
+urls = {Repository = "https://github.com/cvg/LightGlue/"}
+dynamic = ["dependencies"]
+
+[project.optional-dependencies]
+dev = ["black==23.12.1", "flake8", "isort"]
+
+[tool.setuptools]
+packages = ["lightglue"]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+
+[tool.isort]
+profile = "black"
third_party/LightGlue/setup.py
DELETED
@@ -1,27 +0,0 @@
-from pathlib import Path
-from setuptools import setup
-
-description = ["LightGlue"]
-
-with open(str(Path(__file__).parent / "README.md"), "r", encoding="utf-8") as f:
-    readme = f.read()
-with open(str(Path(__file__).parent / "requirements.txt"), "r") as f:
-    dependencies = f.read().split("\n")
-
-setup(
-    name="lightglue",
-    version="0.0",
-    packages=["lightglue"],
-    python_requires=">=3.6",
-    install_requires=dependencies,
-    author="Philipp Lindenberger, Paul-Edouard Sarlin",
-    description=description,
-    long_description=readme,
-    long_description_content_type="text/markdown",
-    url="https://github.com/cvg/LightGlue/",
-    classifiers=[
-        "Programming Language :: Python :: 3",
-        "License :: OSI Approved :: Apache Software License",
-        "Operating System :: OS Independent",
-    ],
-)
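Note: with setup.py removed, the vendored LightGlue installs from pyproject.toml via setuptools; `pip install -e third_party/LightGlue` (or `pip install -e "third_party/LightGlue[dev]"` for the formatting tools) should behave as before, since runtime dependencies are still read dynamically from requirements.txt.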