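# Spaces: Sleeping
# (Status text that appears to have been captured from the hosting page along
# with the source; it is not part of the script.)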
import argparse | |
import numpy as np | |
import imageio | |
import torch | |
from tqdm import tqdm | |
import time | |
import scipy | |
import scipy.io | |
import scipy.misc | |
import os | |
import sys | |
from lib.model_test import D2Net | |
from lib.utils import preprocess_image | |
from lib.pyramid import process_multiscale | |
import cv2 | |
import matplotlib.pyplot as plt | |
from PIL import Image | |
from skimage.feature import match_descriptors | |
from skimage.measure import ransac | |
from skimage.transform import ProjectiveTransform, AffineTransform | |
import pydegensac | |
# Command-line interface: two positional image paths, model/preprocessing
# options, and an optional SIFT comparison run.
parser = argparse.ArgumentParser(description='Feature extraction script')

# Exactly two image paths are required.
parser.add_argument('imgs', type=str, nargs=2)

parser.add_argument(
    '--preprocessing',
    type=str,
    default='caffe',
    help='image preprocessing (caffe or torch)',
)
parser.add_argument('--model_file', type=str, help='path to the full model')

# ReLU stays enabled unless --no-relu is passed.
parser.add_argument(
    '--no-relu',
    dest='use_relu',
    action='store_false',
    default=True,
    help='remove ReLU after the dense feature extraction module',
)

# SIFT matching stays disabled unless --sift is passed.
parser.add_argument(
    '--sift',
    dest='use_sift',
    action='store_true',
    default=False,
    help='Show sift matching as well',
)
def extract(image, args, model, device):
    """Run the network on one image and collect its features.

    Args:
        image: Image array. A 2-D array is expanded to three identical
            channels before preprocessing.
        args: Object exposing a ``preprocessing`` attribute, which is
            forwarded to ``preprocess_image``.
        model: Network passed through to ``process_multiscale``.
        device: Torch device on which the input tensor is created.

    Returns:
        A dict with the keys ``'keypoints'``, ``'scores'`` and
        ``'descriptors'``.
    """
    single_channel = len(image.shape) == 2
    if single_channel:
        # Add a trailing channel axis and replicate it three times.
        image = np.repeat(image[:, :, np.newaxis], 3, -1)

    prepared = preprocess_image(image, preprocessing=args.preprocessing)

    with torch.no_grad():
        # Prepend a batch axis of size one and feed a float32 tensor.
        batch = torch.tensor(
            prepared[np.newaxis, :, :, :].astype(np.float32),
            device=device,
        )
        keypoints, scores, descriptors = process_multiscale(
            batch, model, scales=[1]
        )

    # Swap the first two keypoint columns; presumably (row, col, scale) ->
    # (x, y, scale) -- confirm against process_multiscale.
    return {
        'keypoints': keypoints[:, [1, 0, 2]],
        'scores': scores,
        'descriptors': descriptors,
    }
def rordMatching(image1, image2, feat1, feat2, matcher="BF"):
    """Match the descriptors of two images and plot the homography inliers.

    Descriptors are matched with a cross-checked brute-force L2 matcher, the
    matches are filtered by fitting a homography with pydegensac, and the
    surviving correspondences are drawn side by side with matplotlib.
    Timing and match counts are printed along the way.

    Args:
        image1: First image, as accepted by ``cv2.drawMatches``.
        image2: Second image, as accepted by ``cv2.drawMatches``.
        feat1: Mapping with ``'keypoints'`` and ``'descriptors'`` entries for
            ``image1``; the first two keypoint columns are used as the point
            coordinates.
        feat2: Same as ``feat1``, for ``image2``.
        matcher: Matching strategy. Only ``"BF"`` is implemented.

    Raises:
        ValueError: If ``matcher`` is anything other than ``"BF"``.
    """
    # Fail loudly on an unsupported matcher instead of carrying on with no
    # matches defined.
    if matcher != "BF":
        raise ValueError(
            "Unsupported matcher %r; only 'BF' is implemented." % (matcher,)
        )

    t0 = time.time()
    bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)
    matches = bf.match(feat1['descriptors'], feat2['descriptors'])
    matches = sorted(matches, key=lambda m: m.distance)
    t1 = time.time()
    print("Time to extract matches: ", t1-t0)
    print("Number of raw matches:", len(matches))

    match1 = [m.queryIdx for m in matches]
    match2 = [m.trainIdx for m in matches]
    keypoints_left = feat1['keypoints'][match1, :2]
    keypoints_right = feat2['keypoints'][match2, :2]

    # Seed NumPy's global RNG before the robust fit, as the script always did.
    np.random.seed(0)

    t0 = time.time()
    # Only the inlier mask is needed; the homography itself is discarded.
    _, inliers = pydegensac.findHomography(
        keypoints_left, keypoints_right, 10.0, 0.99, 10000
    )
    t1 = time.time()
    print("Time for ransac: ", t1-t0)

    # Plain int so it can be used for formatting and range() alike.
    n_inliers = int(np.sum(inliers))
    print('Number of inliers: %d.' % n_inliers)

    # Build parallel keypoint lists so that match i links left[i] to right[i].
    inlier_keypoints_left = [
        cv2.KeyPoint(float(point[0]), float(point[1]), 1)
        for point in keypoints_left[inliers]
    ]
    inlier_keypoints_right = [
        cv2.KeyPoint(float(point[0]), float(point[1]), 1)
        for point in keypoints_right[inliers]
    ]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]

    image3 = cv2.drawMatches(
        image1, inlier_keypoints_left,
        image2, inlier_keypoints_right,
        placeholder_matches, None,
        matchColor=(0, 255, 0),
        singlePointColor=(255, 0, 0),
        flags=0,
    )

    plt.figure(figsize=(20, 20))
    plt.imshow(image3)
    plt.axis('off')
    plt.show()
def siftMatching(img1, img2):
    """Match two images with SIFT + FLANN and show the homography inliers.

    SIFT descriptors are matched with a FLANN KD-tree matcher, filtered with
    a 0.7 ratio test, and then filtered again by fitting a homography with
    pydegensac. The inlier matches are displayed in an OpenCV window, which
    blocks until a key is pressed.

    Args:
        img1: First image; anything ``np.array`` turns into an array that
            ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)`` accepts.
        img2: Second image, same requirements as ``img1``.

    Returns:
        A tuple ``(src_pts, dst_pts)`` of float32 arrays of shape ``(k, 2)``
        holding the inlier point coordinates in ``img1`` and ``img2``.
    """
    # Swap the red and blue channels. cv2.imshow below treats images as BGR,
    # so this presumably expects RGB input -- confirm against the caller.
    img1 = np.array(cv2.cvtColor(np.array(img1), cv2.COLOR_BGR2RGB))
    img2 = np.array(cv2.cvtColor(np.array(img2), cv2.COLOR_BGR2RGB))

    # Prefer SIFT from the main OpenCV module; fall back to the contrib
    # location for builds that only expose it under xfeatures2d.
    sift_factory = getattr(cv2, 'SIFT_create', None)
    if sift_factory is None:
        sift_factory = cv2.xfeatures2d.SIFT_create
    sift = sift_factory()

    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)

    FLANN_INDEX_KDTREE = 0
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    matches = flann.knnMatch(des1, des2, k=2)

    # Ratio test. knnMatch can return fewer than two neighbours for a query,
    # so skip those entries instead of failing on unpacking.
    good = [
        pair[0]
        for pair in matches
        if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
    ]

    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 2)

    # Only the inlier mask is needed; the homography itself is discarded.
    _, inliers = pydegensac.findHomography(src_pts, dst_pts, 10.0, 0.99, 10000)

    n_inliers = int(np.sum(inliers))
    print('Number of inliers: %d.' % n_inliers)

    inlier_src = src_pts[inliers].reshape(-1, 2)
    inlier_dst = dst_pts[inliers].reshape(-1, 2)

    # Build parallel keypoint lists so that match i links left[i] to right[i].
    inlier_keypoints_left = [
        cv2.KeyPoint(float(point[0]), float(point[1]), 1)
        for point in inlier_src
    ]
    inlier_keypoints_right = [
        cv2.KeyPoint(float(point[0]), float(point[1]), 1)
        for point in inlier_dst
    ]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(n_inliers)]

    image3 = cv2.drawMatches(
        img1, inlier_keypoints_left,
        img2, inlier_keypoints_right,
        placeholder_matches, None,
    )
    cv2.imshow('Matches', image3)
    cv2.waitKey(0)

    return inlier_src, inlier_dst
if __name__ == '__main__':
    # Entry point: load the model, extract features from both images, show
    # the RoRD matches, and optionally the SIFT matches for comparison.
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")

    model = D2Net(
        model_file=args.model_file,
        use_relu=args.use_relu,
        use_cuda=use_cuda
    )

    # Load both images as 3-channel RGB and close the file handles promptly.
    # Without the conversion, grayscale inputs break the BGR2RGB conversion
    # in siftMatching and palette images decode as index arrays.
    loaded = []
    for path in args.imgs:
        with Image.open(path) as handle:
            loaded.append(np.array(handle.convert('RGB')))
    image1, image2 = loaded

    print('--\nRoRD\n--')
    feat1 = extract(image1, args, model, device)
    feat2 = extract(image2, args, model, device)
    print("Features extracted.")

    rordMatching(image1, image2, feat1, feat2, matcher="BF")

    if args.use_sift:
        print('--\nSIFT\n--')
        siftMatching(image1, image2)