import torch
import time

from lightglue import LightGlue as LightGlue_
from lightglue import SuperPoint
from lightglue.utils import rbd
from kornia.feature import LoFTR as LoFTR_


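def image_rgb2gray(image):
    # in:  torch.Tensor of shape (3, H, W), RGB channel order
    # out: torch.Tensor of shape (1, H, W), grayscale
    # The weights approximate the ITU-R BT.601 luma coefficients (0.299, 0.587, 0.114).
    image = image[0] * 0.3 + image[1] * 0.59 + image[2] * 0.11
    return image[None]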


class LightGlue:
    """SuperPoint keypoints matched with LightGlue; returns matched points and stage timings in seconds."""

    def __init__(self, num_keypoints=2048, device='cuda'):
        self.extractor = SuperPoint(max_num_keypoints=num_keypoints).eval().to(device)  # load the extractor
        self.matcher = LightGlue_(features='superpoint').eval().to(device)  # load the matcher
        self.device = device

    def _now(self):
        # CUDA kernels launch asynchronously: wait for queued GPU work so each
        # checkpoint measures completed computation, not just kernel launches.
        if torch.device(self.device).type == 'cuda':
            torch.cuda.synchronize(torch.device(self.device))
        return time.perf_counter()

    @torch.no_grad()
    def match(self, image0, image1):
        # image0, image1: torch.Tensor of shape (3, H, W), RGB, normalized to [0, 1]
        start_time = self._now()

        image0 = image0.to(self.device)
        image1 = image1.to(self.device)

        preprocess_time = self._now()

        # extract local features
        feats0 = self.extractor.extract(image0)  # auto-resizes the image; disable with resize=None
        feats1 = self.extractor.extract(image1)

        extract_time = self._now()

        # match the features
        matches01 = self.matcher({'image0': feats0, 'image1': feats1})
        feats0, feats1, matches01 = [rbd(x) for x in [feats0, feats1, matches01]]  # remove batch dimension
        matches = matches01['matches']  # index pairs, shape (K, 2)
        points0 = feats0['keypoints'][matches[..., 0]]  # coordinates in image #0, shape (K, 2)
        points1 = feats1['keypoints'][matches[..., 1]]  # coordinates in image #1, shape (K, 2)

        match_time = self._now()

        # durations in seconds: preprocessing, feature extraction, matching
        return (points0, points1,
                preprocess_time - start_time,
                extract_time - preprocess_time,
                match_time - extract_time)
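

# Illustrative sketch (hypothetical helper, not part of lightglue): the indexing
# feats['keypoints'][matches[..., i]] in LightGlue.match gathers one keypoint per
# match, so row k of points0 and row k of points1 form a correspondence. With
# plain Python lists the same pairing reads as below, assuming `matches` is a
# list of (idx0, idx1) index pairs and each keypoint list holds (x, y) tuples.
def gather_matched_points(kpts0, kpts1, matches):
    """Return aligned lists: points0[k] in image #0 corresponds to points1[k] in image #1."""
    points0 = [kpts0[i] for i, _ in matches]  # first column indexes image #0 keypoints
    points1 = [kpts1[j] for _, j in matches]  # second column indexes image #1 keypoints
    return points0, points1
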


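class LoFTR:
    """Detector-free matching with kornia's LoFTR; returns matched points and stage timings in seconds."""

    def __init__(self, pretrained='indoor', device='cuda'):
        self.loftr = LoFTR_(pretrained=pretrained).eval().to(device)
        self.device = device

    def _now(self):
        # CUDA kernels launch asynchronously: wait for queued GPU work so each
        # checkpoint measures completed computation, not just kernel launches.
        if torch.device(self.device).type == 'cuda':
            torch.cuda.synchronize(torch.device(self.device))
        return time.perf_counter()

    @torch.no_grad()
    def match(self, image0, image1):
        # image0, image1: torch.Tensor of shape (3, H, W), RGB, normalized to [0, 1]
        start_time = self._now()

        # LoFTR takes grayscale batches of shape (B, 1, H, W)
        image0 = image_rgb2gray(image0)[None].to(self.device)
        image1 = image_rgb2gray(image1)[None].to(self.device)

        preprocess_time = self._now()

        # LoFTR is detector-free: there is no separate extraction stage, so the
        # reported extraction time is zero and all network time counts as matching.
        extract_time = preprocess_time

        out = self.loftr({'image0': image0, 'image1': image1})
        points0, points1 = out['keypoints0'], out['keypoints1']  # matched coordinates, shape (K, 2) each

        match_time = self._now()

        # durations in seconds: preprocessing, feature extraction, matching
        return (points0, points1,
                preprocess_time - start_time,
                extract_time - preprocess_time,
                match_time - extract_time)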
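

# Illustrative sketch (hypothetical helper, not part of lightglue or kornia):
# a single call to `match` gives noisy timings, and the first calls on a GPU
# include one-off start-up costs (for example CUDA context creation). This
# discards a few warm-up runs, then averages the three durations returned by
# `match`. It assumes only that matcher.match(image0, image1) returns
# (points0, points1, preprocess_s, extract_s, match_s), as both classes above do.
import statistics


def benchmark_matcher(matcher, image0, image1, runs=10, warmup=2):
    """Return the mean (preprocess, extract, match) durations over `runs` timed calls."""
    if runs < 1 or warmup < 0:
        raise ValueError("runs must be >= 1 and warmup must be >= 0")
    for _ in range(warmup):
        matcher.match(image0, image1)  # results discarded: warm-up only
    samples = [matcher.match(image0, image1)[2:] for _ in range(runs)]
    # zip(*samples) regroups the per-run tuples into one sequence per stage
    return tuple(statistics.mean(stage) for stage in zip(*samples))
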