import sys
# small hack to import from src
sys.path.append("src")
import spaces
import gradio as gr
import time
import numpy as np
import omniglue
from omniglue import utils
# Markdown shown at the top of the demo page: title, references and a usage hint.
# NOTE(review): "ArXiv Paper" and "GitHub Repository" render as plain text; they
# look like link labels whose URLs may have been lost — confirm and restore links.
HEADER = """
OmniGlue: Generalizable Feature Matching with Foundation Model Guidance
ArXiv Paper
GitHub Repository
Upload two images 🖼️ of the object and identify matches between them 🚀
"""
# Paper abstract, rendered as Markdown inside the collapsible "Abstract" section.
ABSTRACT = """
The image matching field has been witnessing a continuous emergence of novel learnable feature matching techniques, with ever-improving performance on conventional benchmarks. However, our investigation shows that despite these gains, their potential for real-world applications is restricted by their limited generalization capabilities to novel image domains. In this paper, we introduce OmniGlue, the first learnable image matcher that is designed with generalization as a core principle. OmniGlue leverages broad knowledge from a vision foundation model to guide the feature matching process, boosting generalization to domains not seen at training time. Additionally, we propose a novel keypoint position-guided attention mechanism which disentangles spatial and appearance information, leading to enhanced matching descriptors. We perform comprehensive experiments on a suite of 6 datasets with varied image domains, including scene-level, object-centric and aerial images. OmniGlue’s novel components lead to relative gains on unseen domains of 18.8% with respect to a directly comparable reference model, while also outperforming the recent LightGlue method by 10.1% relatively.
"""
@spaces.GPU
def find_matches(image0, image1):
    """Match two images with OmniGlue and return a visualization of the matches.

    The function works in four stages:

    1. Build the OmniGlue model, including its SuperPoint and DINOv2
       submodules, from the exports under ``./models``.
    2. Run ``FindMatches`` on the two images.
    3. Drop matches whose confidence is not above a fixed threshold.
    4. Draw the surviving matches with ``utils.visualize_matches``.

    Args:
        image0: First image, as delivered by a ``gr.Image`` component.
            This is ``None`` when nothing was uploaded.
        image1: Second image, with the same conventions as ``image0``.

    Returns:
        The image produced by ``utils.visualize_matches``.

    Raises:
        gr.Error: If either image is missing. The UI then shows a readable
            message instead of an internal traceback.
    """
    # An empty gr.Image component yields None. Fail early with a user-facing
    # message instead of letting the model crash on a missing array.
    if image0 is None or image1 is None:
        raise gr.Error("Please upload two images before finding matches.")

    # Load models.
    # NOTE(review): the model is rebuilt on every request. Consider caching it
    # once it is confirmed where ZeroGPU allows the model to be constructed.
    print("> Loading OmniGlue (and its submodules: SuperPoint & DINOv2)...")
    start = time.perf_counter()
    og = omniglue.OmniGlue(
        og_export="./models/og_export",
        sp_export="./models/sp_v6",
        dino_export="./models/dinov2_vitb14_pretrain.pth",
    )
    print(f"> \tTook {time.perf_counter() - start} seconds.")

    # Perform inference.
    print("> Finding matches...")
    start = time.perf_counter()
    match_kp0, match_kp1, match_confidences = og.FindMatches(image0, image1)
    num_matches = match_kp0.shape[0]
    print(f"> \tFound {num_matches} matches.")
    print(f"> \tTook {time.perf_counter() - start} seconds.")

    # Filter by confidence (0.02).
    print("> Filtering matches...")
    match_threshold = 0.02  # Choose any value [0.0, 1.0).
    # Integer indices of the matches strictly above the threshold. This is the
    # same index set the former per-element loop built, computed in one call.
    keep_idx = np.flatnonzero(np.asarray(match_confidences) > match_threshold)
    num_filtered_matches = len(keep_idx)
    match_kp0 = match_kp0[keep_idx]
    match_kp1 = match_kp1[keep_idx]
    match_confidences = match_confidences[keep_idx]
    print(f"> \tFound {num_filtered_matches}/{num_matches} above threshold {match_threshold}")

    # Visualize.
    print("> Visualizing matches...")
    # The kept keypoints are paired row-for-row (kp0[i] <-> kp1[i]), hence the
    # identity matrix passed as the match matrix.
    viz = utils.visualize_matches(
        image0,
        image1,
        match_kp0,
        match_kp1,
        np.eye(num_filtered_matches),
        show_keypoints=True,
        highlight_unmatched=True,
        title=f"{num_filtered_matches} matches",
        line_width=2,
    )
    return viz
with gr.Blocks() as demo:
    # Page header: title, references and usage hint.
    gr.Markdown(value=HEADER)

    # Collapsed-by-default section with the method diagram and the abstract.
    with gr.Accordion(label="Abstract (click to open)", open=False):
        gr.Image(value="res/og_diagram.png")
        gr.Markdown(value=ABSTRACT)

    # The two input images, laid out side by side.
    with gr.Row():
        image_1 = gr.Image()
        image_2 = gr.Image()

    # Trigger and result display.
    button = gr.Button(value="Find Matches")
    output = gr.Image()
    button.click(fn=find_matches, inputs=[image_1, image_2], outputs=output)

    # Ready-made image pairs. With cache_examples="lazy", each example's result
    # is computed on first use rather than at startup.
    gr.Examples(
        examples=[
            ["res/demo1.jpg", "res/demo2.jpg"],
            ["res/tower-1.webp", "res/tower-2.jpeg"],
        ],
        inputs=[image_1, image_2],
        outputs=[output],
        fn=find_matches,
        cache_examples="lazy",
    )

if __name__ == "__main__":
    demo.launch()