Spaces:
Running
Running
Realcat
commited on
Commit
•
e400e91
1
Parent(s):
0bc7901
update: omniglue
Browse files- common/utils.py +12 -7
- hloc/match_dense.py +1 -0
- hloc/matchers/duster.py +1 -2
- hloc/matchers/omniglue.py +1 -3
- hloc/utils/viz.py +5 -3
- third_party/omniglue/src/omniglue/omniglue_extract.py +8 -3
common/utils.py
CHANGED
@@ -642,7 +642,7 @@ def run_matching(
|
|
642 |
ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
|
643 |
choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
|
644 |
matcher_zoo: Dict[str, Any] = None,
|
645 |
-
use_cached_model: bool =
|
646 |
) -> Tuple[
|
647 |
np.ndarray,
|
648 |
np.ndarray,
|
@@ -696,19 +696,21 @@ def run_matching(
|
|
696 |
f"Success! Please be patient and allow for about 2-3 minutes."
|
697 |
f" Due to CPU inference, {key} is quiet slow."
|
698 |
)
|
|
|
699 |
model = matcher_zoo[key]
|
700 |
match_conf = model["matcher"]
|
701 |
# update match config
|
702 |
match_conf["model"]["match_threshold"] = match_threshold
|
703 |
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
704 |
-
t0 = time.time()
|
705 |
cache_key = "{}_{}".format(key, match_conf["model"]["name"])
|
706 |
-
matcher = model_cache.cache_model(cache_key, get_model, match_conf)
|
707 |
if use_cached_model:
|
|
|
|
|
708 |
matcher.conf["max_keypoints"] = extract_max_keypoints
|
709 |
matcher.conf["match_threshold"] = match_threshold
|
710 |
logger.info(f"Loaded cached model {cache_key}")
|
711 |
-
|
|
|
712 |
logger.info(f"Loading model using: {time.time()-t0:.3f}s")
|
713 |
t1 = time.time()
|
714 |
|
@@ -725,13 +727,16 @@ def run_matching(
|
|
725 |
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
726 |
cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
|
727 |
|
728 |
-
extractor = model_cache.cache_model(
|
729 |
-
cache_key, get_feature_model, extract_conf
|
730 |
-
)
|
731 |
if use_cached_model:
|
|
|
|
|
|
|
|
|
732 |
extractor.conf["max_keypoints"] = extract_max_keypoints
|
733 |
extractor.conf["keypoint_threshold"] = keypoint_threshold
|
734 |
logger.info(f"Loaded cached model {cache_key}")
|
|
|
|
|
735 |
|
736 |
pred0 = extract_features.extract(
|
737 |
extractor, image0, extract_conf["preprocessing"]
|
|
|
642 |
ransac_max_iter: int = DEFAULT_RANSAC_MAX_ITER,
|
643 |
choice_geometry_type: str = DEFAULT_SETTING_GEOMETRY,
|
644 |
matcher_zoo: Dict[str, Any] = None,
|
645 |
+
use_cached_model: bool = False,
|
646 |
) -> Tuple[
|
647 |
np.ndarray,
|
648 |
np.ndarray,
|
|
|
696 |
f"Success! Please be patient and allow for about 2-3 minutes."
|
697 |
f" Due to CPU inference, {key} is quiet slow."
|
698 |
)
|
699 |
+
t0 = time.time()
|
700 |
model = matcher_zoo[key]
|
701 |
match_conf = model["matcher"]
|
702 |
# update match config
|
703 |
match_conf["model"]["match_threshold"] = match_threshold
|
704 |
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
|
|
705 |
cache_key = "{}_{}".format(key, match_conf["model"]["name"])
|
|
|
706 |
if use_cached_model:
|
707 |
+
# because of the model cache, we need to update the config
|
708 |
+
matcher = model_cache.cache_model(cache_key, get_model, match_conf)
|
709 |
matcher.conf["max_keypoints"] = extract_max_keypoints
|
710 |
matcher.conf["match_threshold"] = match_threshold
|
711 |
logger.info(f"Loaded cached model {cache_key}")
|
712 |
+
else:
|
713 |
+
matcher = get_model(match_conf)
|
714 |
logger.info(f"Loading model using: {time.time()-t0:.3f}s")
|
715 |
t1 = time.time()
|
716 |
|
|
|
727 |
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
728 |
cache_key = "{}_{}".format(key, extract_conf["model"]["name"])
|
729 |
|
|
|
|
|
|
|
730 |
if use_cached_model:
|
731 |
+
extractor = model_cache.cache_model(
|
732 |
+
cache_key, get_feature_model, extract_conf
|
733 |
+
)
|
734 |
+
# because of the model cache, we need to update the config
|
735 |
extractor.conf["max_keypoints"] = extract_max_keypoints
|
736 |
extractor.conf["keypoint_threshold"] = keypoint_threshold
|
737 |
logger.info(f"Loaded cached model {cache_key}")
|
738 |
+
else:
|
739 |
+
extractor = get_feature_model(extract_conf)
|
740 |
|
741 |
pred0 = extract_features.extract(
|
742 |
extractor, image0, extract_conf["preprocessing"]
|
hloc/match_dense.py
CHANGED
@@ -216,6 +216,7 @@ confs = {
|
|
216 |
"model": {
|
217 |
"name": "omniglue",
|
218 |
"match_threshold": 0.2,
|
|
|
219 |
"features": "null",
|
220 |
},
|
221 |
"preprocessing": {
|
|
|
216 |
"model": {
|
217 |
"name": "omniglue",
|
218 |
"match_threshold": 0.2,
|
219 |
+
"max_keypoints": 2000,
|
220 |
"features": "null",
|
221 |
},
|
222 |
"preprocessing": {
|
hloc/matchers/duster.py
CHANGED
@@ -105,7 +105,7 @@ class Duster(BaseModel):
|
|
105 |
reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
|
106 |
*pts3d_list
|
107 |
)
|
108 |
-
|
109 |
mkpts1 = pts2d_list[1][reciprocal_in_P2]
|
110 |
mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
|
111 |
|
@@ -114,7 +114,6 @@ class Duster(BaseModel):
|
|
114 |
keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
|
115 |
mkpts0 = mkpts0[keep]
|
116 |
mkpts1 = mkpts1[keep]
|
117 |
-
breakpoint()
|
118 |
pred = {
|
119 |
"keypoints0": torch.from_numpy(mkpts0),
|
120 |
"keypoints1": torch.from_numpy(mkpts1),
|
|
|
105 |
reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(
|
106 |
*pts3d_list
|
107 |
)
|
108 |
+
logger.info(f"Found {num_matches} matches")
|
109 |
mkpts1 = pts2d_list[1][reciprocal_in_P2]
|
110 |
mkpts0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]
|
111 |
|
|
|
114 |
keep = np.round(np.linspace(0, len(mkpts0) - 1, top_k)).astype(int)
|
115 |
mkpts0 = mkpts0[keep]
|
116 |
mkpts1 = mkpts1[keep]
|
|
|
117 |
pred = {
|
118 |
"keypoints0": torch.from_numpy(mkpts0),
|
119 |
"keypoints1": torch.from_numpy(mkpts1),
|
hloc/matchers/omniglue.py
CHANGED
@@ -39,7 +39,6 @@ class OmniGlue(BaseModel):
|
|
39 |
subprocess.run(cmd, check=True)
|
40 |
else:
|
41 |
logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
|
42 |
-
|
43 |
self.net = omniglue.OmniGlue(
|
44 |
og_export=str(og_model_path),
|
45 |
sp_export=str(sp_model_path),
|
@@ -54,9 +53,8 @@ class OmniGlue(BaseModel):
|
|
54 |
image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255
|
55 |
image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255
|
56 |
match_kp0, match_kp1, match_confidences = self.net.FindMatches(
|
57 |
-
image0_rgb_np, image1_rgb_np
|
58 |
)
|
59 |
-
|
60 |
# filter matches
|
61 |
match_threshold = self.conf["match_threshold"]
|
62 |
keep_idx = []
|
|
|
39 |
subprocess.run(cmd, check=True)
|
40 |
else:
|
41 |
logger.error(f"Invalid dinov2 model: {dino_model_path.name}")
|
|
|
42 |
self.net = omniglue.OmniGlue(
|
43 |
og_export=str(og_model_path),
|
44 |
sp_export=str(sp_model_path),
|
|
|
53 |
image0_rgb_np = image0_rgb_np.astype(np.uint8) # RGB, 0-255
|
54 |
image1_rgb_np = image1_rgb_np.astype(np.uint8) # RGB, 0-255
|
55 |
match_kp0, match_kp1, match_confidences = self.net.FindMatches(
|
56 |
+
image0_rgb_np, image1_rgb_np, self.conf["max_keypoints"]
|
57 |
)
|
|
|
58 |
# filter matches
|
59 |
match_threshold = self.conf["match_threshold"]
|
60 |
keep_idx = []
|
hloc/utils/viz.py
CHANGED
@@ -65,9 +65,11 @@ def plot_keypoints(kpts, colors="lime", ps=4):
|
|
65 |
if not isinstance(colors, list):
|
66 |
colors = [colors] * len(kpts)
|
67 |
axes = plt.gcf().axes
|
68 |
-
|
69 |
-
a
|
70 |
-
|
|
|
|
|
71 |
|
72 |
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
|
73 |
"""Plot matches for a pair of existing images.
|
|
|
65 |
if not isinstance(colors, list):
|
66 |
colors = [colors] * len(kpts)
|
67 |
axes = plt.gcf().axes
|
68 |
+
try:
|
69 |
+
for a, k, c in zip(axes, kpts, colors):
|
70 |
+
a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
|
71 |
+
except IndexError as e:
|
72 |
+
pass
|
73 |
|
74 |
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
|
75 |
"""Plot matches for a pair of existing images.
|
third_party/omniglue/src/omniglue/omniglue_extract.py
CHANGED
@@ -46,13 +46,18 @@ class OmniGlue:
|
|
46 |
dino_export, feature_layer=1
|
47 |
)
|
48 |
|
49 |
-
def FindMatches(
|
|
|
|
|
|
|
|
|
|
|
50 |
"""TODO(omniglue): docstring."""
|
51 |
height0, width0 = image0.shape[:2]
|
52 |
height1, width1 = image1.shape[:2]
|
53 |
# TODO: numpy to torch inputs
|
54 |
-
sp_features0 = self.sp_extract(image0, num_features=
|
55 |
-
sp_features1 = self.sp_extract(image1, num_features=
|
56 |
dino_features0 = self.dino_extract(image0)
|
57 |
dino_features1 = self.dino_extract(image1)
|
58 |
dino_descriptors0 = dino_extract.get_dino_descriptors(
|
|
|
46 |
dino_export, feature_layer=1
|
47 |
)
|
48 |
|
49 |
+
def FindMatches(
|
50 |
+
self,
|
51 |
+
image0: np.ndarray,
|
52 |
+
image1: np.ndarray,
|
53 |
+
max_keypoints: int = 2048,
|
54 |
+
):
|
55 |
"""TODO(omniglue): docstring."""
|
56 |
height0, width0 = image0.shape[:2]
|
57 |
height1, width1 = image1.shape[:2]
|
58 |
# TODO: numpy to torch inputs
|
59 |
+
sp_features0 = self.sp_extract(image0, num_features=max_keypoints)
|
60 |
+
sp_features1 = self.sp_extract(image1, num_features=max_keypoints)
|
61 |
dino_features0 = self.dino_extract(image0)
|
62 |
dino_features1 = self.dino_extract(image1)
|
63 |
dino_descriptors0 = dino_extract.get_dino_descriptors(
|