|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Demo script for performing OmniGlue inference.""" |
|
|
|
import sys |
|
import time |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import omniglue |
|
from omniglue import utils |
|
from PIL import Image |
|
|
|
|
|
def _load_rgb_image(image_fp: str) -> np.ndarray:
    """Read an image file and return it as a (height, width, 3) uint8 array.

    Converting to RGB normalises grayscale, palette and RGBA files to three
    channels, so the matcher always receives the same layout regardless of the
    file's own mode. The ``with`` block closes the underlying file handle
    instead of leaving it open until garbage collection.
    """
    with Image.open(image_fp) as img:
        return np.array(img.convert("RGB"))


def _filter_by_confidence(match_kp0, match_kp1, match_confidences, threshold):
    """Keep only the matches whose confidence is strictly above ``threshold``.

    Returns the filtered ``(match_kp0, match_kp1, match_confidences)``, with
    rows selected by the same integer indices in all three arrays.
    """
    # flatnonzero yields integer indices (possibly empty), matching the
    # semantics of indexing with a list of kept positions.
    keep_idx = np.flatnonzero(np.asarray(match_confidences) > threshold)
    return match_kp0[keep_idx], match_kp1[keep_idx], match_confidences[keep_idx]


def main(argv, match_threshold: float = 0.02) -> None:
    """Match two images with OmniGlue and save a visualization of the result.

    Args:
      argv: Command line in ``sys.argv`` form: the script name followed by
        exactly two image file paths.
      match_threshold: Matches with confidence <= this value are discarded
        before visualization.

    On wrong usage a message is printed to stderr and the function returns
    without loading anything. Otherwise the visualization is written to
    ``./demo_output.png`` (relative to the current working directory).
    """
    if len(argv) != 3:
        print("error - usage: python demo.py <img1_fp> <img2_fp>", file=sys.stderr)
        return
    image0_fp, image1_fp = argv[1], argv[2]

    # Load images.
    print("> Loading images...")
    image0 = _load_rgb_image(image0_fp)
    image1 = _load_rgb_image(image1_fp)

    # Load models. Export paths are relative to the current working directory.
    print("> Loading OmniGlue (and its submodules: SuperPoint & DINOv2)...")
    start = time.perf_counter()
    og = omniglue.OmniGlue(
        og_export="./models/og_export",
        sp_export="./models/sp_v6",
        dino_export="./models/dinov2_vitb14_pretrain.pth",
    )
    print(f"> \tTook {time.perf_counter() - start} seconds.")

    # Perform inference.
    print("> Finding matches...")
    start = time.perf_counter()
    match_kp0, match_kp1, match_confidences = og.FindMatches(image0, image1)
    num_matches = match_kp0.shape[0]
    print(f"> \tFound {num_matches} matches.")
    print(f"> \tTook {time.perf_counter() - start} seconds.")

    # Filter by confidence.
    print("> Filtering matches...")
    match_kp0, match_kp1, match_confidences = _filter_by_confidence(
        match_kp0, match_kp1, match_confidences, match_threshold
    )
    num_filtered_matches = match_kp0.shape[0]
    print(f"> \tFound {num_filtered_matches}/{num_matches} above threshold {match_threshold}")

    # Visualize.
    print("> Visualizing matches...")
    viz = utils.visualize_matches(
        image0,
        image1,
        match_kp0,
        match_kp1,
        # NOTE(review): presumably a match matrix pairing match_kp0[i] with
        # match_kp1[i]; np.eye is dense, so memory grows quadratically with the
        # number of matches — confirm against utils.visualize_matches.
        np.eye(num_filtered_matches),
        show_keypoints=True,
        highlight_unmatched=True,
        title=f"{num_filtered_matches} matches",
        line_width=2,
    )
    # The figure is only for interactive display; the saved file comes from
    # plt.imsave, which writes the array directly and does not use the figure.
    plt.figure(figsize=(20, 10), dpi=100, facecolor="w", edgecolor="k")
    plt.axis("off")
    plt.imshow(viz)
    plt.imsave("./demo_output.png", viz)
    print("> \tSaved visualization to ./demo_output.png")
|
|
|
|
|
if __name__ == "__main__":
    # Run the demo only when executed as a script (not on import), handing the
    # complete command line, script name included, to the entry point.
    main(argv=sys.argv)
|
|