# SRPose / baselines / matchers.py
# Uploaded by FrickYinn (commit e170a8e, "Upload 53 files").
import torch
import time
from lightglue import LightGlue as LightGlue_
from lightglue import SuperPoint
from lightglue.utils import rbd
from kornia.feature import LoFTR as LoFTR_
def image_rgb2gray(image):
    """Collapse an RGB image tensor into a single luminance channel.

    Args:
        image: tensor shaped (3, H, W) with channels in R, G, B order.

    Returns:
        Tensor shaped (1, H, W) equal to 0.3*R + 0.59*G + 0.11*B.
    """
    red, green, blue = image[0], image[1], image[2]
    luminance = red * 0.3 + green * 0.59 + blue * 0.11
    # Re-insert the leading channel axis that the per-channel indexing removed.
    return luminance[None]
class LightGlue():
    """SuperPoint + LightGlue sparse matching baseline.

    Wraps the ``lightglue`` package: a SuperPoint extractor produces local
    features (keypoints) for each image, and a LightGlue matcher pairs them.
    """

    def __init__(self, num_keypoints=2048, device='cuda'):
        """Load the extractor and matcher in eval mode on ``device``.

        Args:
            num_keypoints: passed to SuperPoint as ``max_num_keypoints``.
            device: torch device (string or ``torch.device``) that the models
                and the input images are moved to.
        """
        self.extractor = SuperPoint(max_num_keypoints=num_keypoints).eval().to(device)  # load the extractor
        self.matcher = LightGlue_(features='superpoint').eval().to(device)  # load the matcher
        self.device = device

    def _now(self):
        """Return a ``time.perf_counter()`` timestamp taken after queued device work finishes.

        CUDA kernels are launched asynchronously. Reading the host clock
        without synchronizing would charge GPU work to whichever later stage
        happens to block first. On non-CUDA devices this is just
        ``time.perf_counter()``.
        """
        if torch.device(self.device).type == 'cuda':
            torch.cuda.synchronize(self.device)
        return time.perf_counter()

    @torch.no_grad()
    def match(self, image0, image1):
        """Match two images and time each stage.

        Args:
            image0, image1: image tensors shaped (3, H, W). They are presumably
                RGB in the value range that lightglue's SuperPoint expects
                (TODO: confirm against callers).

        Returns:
            Tuple ``(points0, points1, preprocess_s, extract_s, match_s)``:
            - ``points0`` and ``points1`` are the (K, 2) coordinates of the K
              matched keypoints in image 0 and image 1. Row i of each forms
              one correspondence.
            - The remaining three values are wall-clock durations in seconds
              for moving the inputs to the device, feature extraction, and
              matching.
        """
        start_time = self._now()
        image0 = image0.to(self.device)
        image1 = image1.to(self.device)
        preprocess_time = self._now()

        # Extract local features (the extractor auto-resizes; disable with resize=None).
        feats0 = self.extractor.extract(image0)
        feats1 = self.extractor.extract(image1)
        extract_time = self._now()

        # Match the features, then remove the batch dimension from every output.
        matches01 = self.matcher({'image0': feats0, 'image1': feats1})
        feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]
        matches = matches01['matches']  # indices with shape (K, 2)
        points0 = feats0['keypoints'][matches[..., 0]]  # coordinates in image 0, shape (K, 2)
        points1 = feats1['keypoints'][matches[..., 1]]  # coordinates in image 1, shape (K, 2)
        match_time = self._now()

        return (points0, points1,
                preprocess_time - start_time,
                extract_time - preprocess_time,
                match_time - extract_time)
class LoFTR():
    """LoFTR matching baseline using kornia's implementation.

    A single forward pass over the image pair yields the matched coordinates.
    There is no separate feature-extraction step, so ``match`` reports an
    extraction duration of 0. It still returns that slot so the tuple layout
    matches the other baselines in this module.
    """

    def __init__(self, pretrained='indoor', device='cuda'):
        """Load kornia's LoFTR in eval mode on ``device``.

        Args:
            pretrained: name of the pretrained weight set passed to kornia's
                ``LoFTR`` (e.g. 'indoor').
            device: torch device (string or ``torch.device``) for the model
                and the inputs.
        """
        self.loftr = LoFTR_(pretrained=pretrained).eval().to(device)
        self.device = device

    def _now(self):
        """Return a ``time.perf_counter()`` timestamp taken after queued device work finishes.

        CUDA kernels are launched asynchronously. Reading the host clock
        without synchronizing would charge GPU work to whichever later stage
        happens to block first. On non-CUDA devices this is just
        ``time.perf_counter()``.
        """
        if torch.device(self.device).type == 'cuda':
            torch.cuda.synchronize(self.device)
        return time.perf_counter()

    @torch.no_grad()
    def match(self, image0, image1):
        """Match two images and time each stage.

        Args:
            image0, image1: image tensors shaped (C, H, W).
                - RGB inputs (C == 3) are converted with ``image_rgb2gray``.
                - Single-channel inputs (C == 1) are used as-is.

        Returns:
            Tuple ``(points0, points1, preprocess_s, extract_s, match_s)``:
            - ``points0`` and ``points1`` are kornia's ``'keypoints0'`` and
              ``'keypoints1'`` outputs, the matched coordinates in image 0
              and image 1.
            - The remaining three values are wall-clock durations in seconds
              for grayscale conversion plus device transfer, extraction
              (always 0 for LoFTR), and matching.
        """
        start_time = self._now()
        batch = {}
        for key, image in (('image0', image0), ('image1', image1)):
            # The model is fed one-channel images. Convert RGB input and pass
            # through input that already has a single channel.
            gray = image if image.shape[0] == 1 else image_rgb2gray(image)
            batch[key] = gray[None].to(self.device)  # add batch dim -> (1, 1, H, W)
        preprocess_time = self._now()

        # No separate extraction stage, so its duration is reported as exactly 0.
        extract_time = preprocess_time

        out = self.loftr(batch)
        points0, points1 = out['keypoints0'], out['keypoints1']
        match_time = self._now()

        return (points0, points1,
                preprocess_time - start_time,
                extract_time - preprocess_time,
                match_time - extract_time)