SoundOfWater / shared /utils /keypoint_matching.py
bpiyush's picture
Upload folder using huggingface_hub
eafbf97 verified
raw
history blame
10.6 kB
"""Implements keypoint matching for a pair of images."""
import os
import numpy as np
import PIL
import cv2
import matplotlib.pyplot as plt
def show_single_image(img, figsize=(7, 5), title="Single image"):
    """Render one image in its own matplotlib figure and show it.

    Args:
        img: Image data accepted by ``plt.imshow``.
        figsize: Forwarded to ``plt.figure`` as ``figsize``.
        title: Text set as the title of the plot.
    """
    # pyplot keeps track of the current figure, so the handle is not stored.
    plt.figure(figsize=figsize)
    # Hide ticks and frame before drawing the image.
    plt.axis("off")
    plt.imshow(img)
    plt.title(title)
    plt.show()
def show_two_images(img1, img2, title="Two images"):
    """Show two images next to each other in a single 1 x 2 figure.

    Args:
        img1: Image drawn in the left panel.
        img2: Image drawn in the right panel.
        title: Text set as the figure-level title.
    """
    _, axes = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
    # Same treatment for both panels: no axis decorations, then the image.
    for axis, image in zip(axes, (img1, img2)):
        axis.axis("off")
        axis.imshow(image)
    plt.suptitle(title)
    plt.show()
def show_three_images(img1, img2, img3, ax1_title="", ax2_title="", ax3_title="", title="Three images"):
    """Show three images next to each other in a single 1 x 3 figure.

    Args:
        img1, img2, img3: Images drawn left to right.
        ax1_title, ax2_title, ax3_title: Per-panel titles, in the same order.
        title: Text set as the figure-level title.
    """
    _, axes = plt.subplots(1, 3, figsize=(15, 5), constrained_layout=True)
    panels = (
        (img1, ax1_title),
        (img2, ax2_title),
        (img3, ax3_title),
    )
    # Each panel: hide axis decorations, draw the image, then caption it.
    for axis, (image, caption) in zip(axes, panels):
        axis.axis("off")
        axis.imshow(image)
        axis.set_title(caption)
    plt.suptitle(title)
    plt.show()
class KeypointMatcher:
    """Detects SIFT keypoints in a pair of images and matches their descriptors.

    ``__init__`` creates two attributes:
        SIFT: the object returned by ``cv2.SIFT_create``.
        BFMatcher: the object returned by ``cv2.BFMatcher``.
    """

    def __init__(self, **sift_args) -> None:
        """Create the detector and the brute-force matcher.

        Args:
            **sift_args: Forwarded unchanged to ``cv2.SIFT_create``.
        """
        self.SIFT = cv2.SIFT_create(**sift_args)
        self.BFMatcher = cv2.BFMatcher()

    @staticmethod
    def _to_gray(img) -> np.ndarray:
        """Return ``img`` as a NumPy array, converted to grayscale if needed.

        A two-dimensional input is already single-channel and is returned
        without calling ``cv2.cvtColor``; ``COLOR_RGB2GRAY`` expects a
        multi-channel image, so converting unconditionally breaks grayscale
        inputs such as images read with ``cv2.IMREAD_GRAYSCALE``.
        """
        arr = np.array(img)
        if arr.ndim == 2:
            return arr
        return cv2.cvtColor(arr, cv2.COLOR_RGB2GRAY)

    @staticmethod
    def _check_images(img1: np.ndarray, img2: np.ndarray):
        """Assert that both images are two-dimensional NumPy arrays.

        Raises:
            AssertionError: if either input is not an ``np.ndarray`` or is not 2-D.
        """
        assert isinstance(img1, np.ndarray), "img1 must be a numpy array"
        assert len(img1.shape) == 2, "img1 must be a 2-D (grayscale) array"
        assert isinstance(img2, np.ndarray), "img2 must be a numpy array"
        assert len(img2.shape) == 2, "img2 must be a 2-D (grayscale) array"
        # The two images are deliberately allowed to have different shapes.

    @staticmethod
    def _show_matches(img1, kp1, img2, kp2, matches, K=10, figsize=(10, 5), drawMatches_args=None):
        """Draw ``K`` randomly chosen matches (sampled with replacement) and show them.

        Args:
            img1, img2: The two images the keypoints belong to.
            kp1, kp2: Keypoints of ``img1`` and ``img2``.
            matches: Sequence of matches between ``kp1`` and ``kp2``.
            K: Number of matches to sample for display.
            figsize: Forwarded to ``show_single_image``.
            drawMatches_args: Extra keyword arguments for ``cv2.drawMatches``;
                ``None`` selects ``matchesThickness=3`` and black single points.

        Returns:
            The image produced by ``cv2.drawMatches``, or ``None`` when there
            are no matches to draw.
        """
        if drawMatches_args is None:
            # Built per call so the default is never shared between calls.
            drawMatches_args = dict(matchesThickness=3, singlePointColor=(0, 0, 0))
        if len(matches) == 0:
            return None
        # Sample indices rather than the match objects themselves, so the
        # matches never have to be packed into a NumPy array.
        chosen = np.random.choice(len(matches), K)
        selected_matches = [matches[idx] for idx in chosen]
        img3 = cv2.drawMatches(img1, kp1, img2, kp2, selected_matches, outImg=None, **drawMatches_args)
        show_single_image(img3, figsize=figsize, title=f"Randomly selected K = {K} matches between the pair of images.")
        return img3

    def match(self, img1: "PIL.Image.Image", img2: "PIL.Image.Image", show_matches: bool = True):
        """Find, describe and match keypoints in the given pair of images.

        Args:
            img1, img2: Images convertible with ``np.array``; either RGB
                (converted to grayscale here) or already 2-D grayscale.
            show_matches: When true and at least one match exists, display a
                random subset of the matches.

        Returns:
            Tuple ``(matches, kp1, des1, kp2, des2)`` where ``matches`` is a
            list sorted by ascending ``distance``. ``matches`` is empty when
            either image yields no descriptors.
        """
        img1 = self._to_gray(img1)
        img2 = self._to_gray(img2)

        # check input images
        self._check_images(img1, img2)

        # find kps and descriptors in each image
        kp1, des1 = self.SIFT.detectAndCompute(img1, None)
        kp2, des2 = self.SIFT.detectAndCompute(img2, None)

        if des1 is None or des2 is None:
            # No descriptors on at least one side: nothing to match.
            matches = []
        else:
            # brute-force matching, best (smallest distance) first
            matches = sorted(self.BFMatcher.match(des1, des2), key=lambda m: m.distance)

        if show_matches and matches:
            self._show_matches(img1, kp1, img2, kp2, matches)

        return matches, kp1, des1, kp2, des2
def warp(im, M, output_shape):
    """Resample ``im`` through a 6-parameter map evaluated at every output pixel.

    For each output pixel ``(i, j)`` (row, column) the source location is

        u = M[0] * i + M[1] * j + M[4]
        v = M[2] * i + M[3] * j + M[5]

    rounded to the nearest integer, and ``out[i, j] = im[u, v]`` whenever
    ``(u, v)`` lies inside ``im``. ``M`` therefore maps *output* (row, col)
    coordinates to *input* (row, col) coordinates.

    Args:
        im: Source image, indexed as ``im[row, col]``; trailing channel axes
            are carried over to the output.
        M: Six parameters in the order shown above.
        output_shape: ``(rows, cols)`` of the result.

    Returns:
        Float array of shape ``(rows, cols) + im.shape[2:]``; pixels whose
        source location falls outside ``im`` (or is not finite) stay 0.
    """
    im = np.asarray(im)
    row_i, row_j, col_i, col_j, row_off, col_off = np.asarray(M, dtype=float).reshape(-1)
    rows, cols = int(output_shape[0]), int(output_shape[1])

    # Evaluate the map for the whole pixel grid at once instead of per pixel.
    ii, jj = np.meshgrid(np.arange(rows), np.arange(cols), indexing="ij")
    u = np.rint(row_i * ii + row_j * jj + row_off)
    v = np.rint(col_i * ii + col_j * jj + col_off)

    # Bounds are tested on the float values, so NaN/inf never reach the
    # integer cast below (comparisons with NaN are False).
    inside = (u >= 0) & (u < im.shape[0]) & (v >= 0) & (v < im.shape[1])

    out = np.zeros((rows, cols) + im.shape[2:])
    out[inside] = im[u[inside].astype(np.intp), v[inside].astype(np.intp)]
    return out
def project_2d_to_6d(X: np.ndarray):
    """Expand points X (N x 2) into the matrix Z (2N x 6).

    For the k-th point ``(x, y)``, row ``2k`` of Z is ``[x, y, 0, 0, 1, 0]``
    and row ``2k + 1`` is ``[0, 0, x, y, 0, 1]``.
    """
    num_points = len(X)
    assert X.shape == (num_points, 2)

    Z = np.zeros((2 * num_points, 6))
    # Strided slices are views, so writing to them fills Z in place.
    first_rows = Z[0::2]
    second_rows = Z[1::2]

    # even rows: point in columns 0-1, constant 1 in column 4
    first_rows[:, 0:2] = X
    first_rows[:, 4] = 1.0

    # odd rows: point in columns 2-3, constant 1 in column 5
    second_rows[:, 2:4] = X
    second_rows[:, 5] = 1.0

    return Z
def project_6d_to_2d(Z: np.ndarray):
    """Recover points X (N x 2) from the matrix Z (2N x 6).

    The points are read from columns 0-1 of the even rows; columns 2-3 of the
    odd rows must hold the same values, otherwise an ``AssertionError`` is raised.
    """
    num_points = len(Z) // 2
    assert Z.shape == (2 * num_points, 6)

    points_in_even_rows = Z[0::2, :2]
    points_in_odd_rows = Z[1::2, 2:4]
    # Both halves of Z must encode the same points.
    assert np.all(points_in_even_rows == points_in_odd_rows)

    return points_in_even_rows
def project_2d_to_1d(X: np.ndarray):
    """Interleave the two columns of X (N x 2) into a flat vector (2N,).

    The result is ``[X[0, 0], X[0, 1], X[1, 0], X[1, 1], ...]``.
    """
    num_points = len(X)
    flat = np.zeros(2 * num_points)
    # Column 0 lands on even positions, column 1 on odd positions.
    for column in (0, 1):
        flat[column::2] = X[:, column]
    return flat
def project_1d_to_2d(Z: np.ndarray):
    """Fold a flat vector Z (2N,) back into points X (N x 2).

    Even positions of Z become column 0 and odd positions become column 1.
    """
    num_points = len(Z) // 2
    assert Z.shape == (2 * num_points,)

    X = np.zeros((num_points, 2))
    # Row k of the reshaped vector is (Z[2k], Z[2k + 1]); copy it into X.
    X[...] = Z.reshape(num_points, 2)
    return X
def rigid_body_transform(X: np.ndarray, params: np.ndarray):
    """Apply the 6-parameter transform ``params`` to points X (N x 2).

    X is expanded with ``project_2d_to_6d``, multiplied by ``params`` and the
    resulting (2N,) vector is folded back to (N x 2) with ``project_1d_to_2d``.

    Returns:
        The transformed points, shape (N x 2).
    """
    num_points = len(X)
    assert X.shape == (num_points, 2)

    design = project_2d_to_6d(X)
    # (2N x 6) times (6,) gives the interleaved transformed coordinates.
    stacked = design @ params
    transformed = project_1d_to_2d(stacked)

    assert transformed.shape == (num_points, 2)
    return transformed
def rigid_body_transform_params(X1: np.ndarray, X2: np.ndarray):
    """Estimate the 6 transform parameters that map points X1 onto X2.

    Both inputs are (N x 2). The parameters are obtained as
    ``pinv(A) @ b`` where ``A = project_2d_to_6d(X1)`` and
    ``b = project_2d_to_1d(X2)``.

    Returns:
        Parameter vector of shape (6,).
    """
    num_points = len(X1)
    assert X1.shape == X2.shape
    assert X1.shape == (num_points, 2)

    # A @ params = b  =>  params = pseudoinverse(A) @ b
    design = project_2d_to_6d(X1)
    assert design.shape == (2 * num_points, 6)

    targets = project_2d_to_1d(X2)
    assert targets.shape == (2 * num_points,)

    return np.linalg.pinv(design) @ targets
class ImageAlignment:
    """Aligns a pair of images from matched keypoints using RANSAC.

    The transform has six parameters ``[a, b, c, d, tx, ty]`` such that a
    keypoint ``(x, y)`` of image 1 maps to
    ``(a*x + b*y + tx, c*x + d*y + ty)`` in image 2.
    """

    def __init__(self) -> None:
        pass

    @staticmethod
    def show_transformed_points(img1, img2, X1, kp1, kp2, matches, params, num_inliers, num_to_show=20):
        """Plot where randomly chosen image-1 keypoints land under ``params``.

        The two images are placed side by side; each sampled keypoint of
        image 1 is connected to its transformed position, drawn in the
        image-2 half. The figure is also saved under ``./results/``.

        Args:
            img1, img2: 2-D images; they are stacked with ``np.hstack``.
            X1: Unused; kept so existing callers keep working.
            kp1, kp2: Keypoints of the two images (``kp2`` is not read here).
            matches: Matches whose ``queryIdx`` indexes ``kp1``.
            params: The six transform parameters.
            num_inliers: Used in the title and in the output file name.
            num_to_show: Number of matches sampled (with replacement).
        """
        import matplotlib.cm as cm

        if len(matches) == 0:
            # Nothing to sample from.
            return

        _, W1 = img1.shape
        img = np.hstack([img1, img2])

        fig, ax = plt.subplots(1, 1, figsize=(15, 6))
        # Draw the background once instead of once per sampled match.
        ax.imshow(img, cmap="gray")
        ax.axis("off")

        colors = cm.rainbow(np.linspace(0, 1, num_to_show))
        chosen = np.random.choice(len(matches), num_to_show)
        for i, match_idx in enumerate(chosen):
            # keypoint of image 1 for this match
            x1, y1 = kp1[matches[match_idx].queryIdx].pt

            # get (x1, y1) transformed to (x1_transformed, y1_transformed)
            A = project_2d_to_6d(np.array([[x1, y1]]))
            (x1_transformed, y1_transformed) = np.dot(A, params)

            # shift by W1 so the target is drawn in the right-hand image
            ax.scatter(x1_transformed + W1, y1_transformed, s=200, marker="x", color=colors[i])
            ax.plot(
                (x1, x1_transformed + W1), (y1, y1_transformed),
                linestyle="--", color=colors[i], marker="o",
            )

        ax.set_title(
            f"Points in image 1 mapped to transformed points estimated by {num_inliers} points.",
            fontsize=18,
        )
        os.makedirs("./results/", exist_ok=True)
        plt.savefig(f"./results/match_transformed_inliers_{num_inliers}.png", bbox_inches="tight")
        plt.show()

    def ransac(
            self, img1, kp1, img2, kp2, matches, num_matches=6, max_iter=500,
            radius_in_px=10, show_transformed=True, inlier_th_for_show=1000
        ):
        """Run RANSAC over the matches and return the best transform parameters.

        Each iteration fits the parameters to ``num_matches`` random matches,
        transforms all image-1 keypoints and counts those that land within
        ``radius_in_px`` of their matched image-2 keypoint.

        Args:
            img1, img2: Only forwarded to ``show_transformed_points``.
            kp1, kp2: Keypoints of the two images.
            matches: Matches with ``queryIdx`` (into ``kp1``) and ``trainIdx``
                (into ``kp2``).
            num_matches: Sample size per iteration.
            max_iter: Number of iterations.
            radius_in_px: Inlier distance threshold.
            show_transformed: Whether to visualise a new best model.
            inlier_th_for_show: A new best model is only visualised when its
                inlier count exceeds this value.

        Returns:
            The parameters with the most inliers, or ``None`` when there are
            no matches or no model ever had an inlier.
        """
        best_inlier_count = 0
        best_params = None

        if len(matches) == 0:
            return best_params

        # coordinates of all matched points in image 1 and image 2
        X1 = np.array([kp1[m.queryIdx].pt for m in matches])
        X2 = np.array([kp2[m.trainIdx].pt for m in matches])

        # Draw distinct matches whenever enough are available: repeated
        # correspondences add no constraint to the fit.
        with_replacement = len(matches) < num_matches

        for _ in range(max_iter):
            sample = np.random.choice(len(matches), num_matches, replace=with_replacement)

            # fit the transform to the sampled correspondences
            params = rigid_body_transform_params(X1[sample], X2[sample])

            # transform X1 to get X2_transformed
            X2_transformed = rigid_body_transform(X1, params)

            # find inliers
            diff = np.linalg.norm(X2_transformed - X2, axis=1)
            num_inliers = int(np.count_nonzero(diff < radius_in_px))

            if num_inliers > best_inlier_count:
                print(f"Found {num_inliers} inliers!")
                best_params = params
                best_inlier_count = num_inliers

                if show_transformed and num_inliers > inlier_th_for_show:
                    self.show_transformed_points(img1, img2, X1, kp1, kp2, matches, best_params, num_inliers)

        return best_params

    @staticmethod
    def _backward_row_col_params(params):
        """Convert forward (x, y) parameters into the map ``warp`` expects.

        ``params`` sends image-1 ``(x, y)`` to image-2 ``(x, y)``. ``warp``
        needs, for every *output* ``(row, col)``, the *source* ``(row, col)``,
        i.e. the inverse transform with the two axes swapped.
        """
        a, b, c, d, tx, ty = np.asarray(params, dtype=float).reshape(-1)
        A_inv = np.linalg.inv(np.array([[a, b], [c, d]]))
        t_inv = -A_inv @ np.array([tx, ty])
        # source row (y) first, then source column (x); i = row = y, j = col = x
        return np.array([
            A_inv[1, 1], A_inv[1, 0],
            A_inv[0, 1], A_inv[0, 0],
            t_inv[1], t_inv[0],
        ])

    def align(
            self, img1, kp1, img2, kp2, matches, num_matches=6,
            max_iter=500, show_warped_image=True,
            save_warped=False, path="results/sample.png",
            method="custom"
        ):
        """Estimate the transform with RANSAC and warp image 1 towards image 2.

        Args:
            img1, kp1, img2, kp2, matches: See ``ransac``.
            num_matches, max_iter: Forwarded to ``ransac``.
            show_warped_image: Show image 1, image 2 and the warped image 1.
            save_warped: Save the warped image to ``path``.
            path: Output file used when ``save_warped`` is true.
            method: ``"custom"`` uses ``warp``; any other value uses
                ``cv2.warpAffine``.

        Returns:
            The best parameters found by ``ransac``; ``None`` (and no warping)
            when ``ransac`` found no model.

        Raises:
            numpy.linalg.LinAlgError: with ``method="custom"`` when the
                estimated linear part is singular and cannot be inverted.
        """
        best_params = self.ransac(img1, kp1, img2, kp2, matches, max_iter=max_iter, num_matches=num_matches)
        if best_params is None:
            # No model to warp with.
            return None

        rows, cols = img1.shape[:2]
        if method == 'custom':
            # warp() pulls pixels from the source, so hand it the inverse map
            # in (row, col) order; this matches what cv2.warpAffine produces.
            img1_warped = warp(img1, self._backward_row_col_params(best_params), (rows, cols))
        else:
            # 2 x 3 matrix [[a, b, tx], [c, d, ty]] for cv2.warpAffine
            M = np.zeros((2, 3))
            M[0, :2] = best_params[:2]
            M[1, :2] = best_params[2:4]
            M[0, 2] = best_params[4]
            M[1, 2] = best_params[5]
            img1_warped = cv2.warpAffine(img1, M, (cols, rows))

        if show_warped_image:
            show_three_images(
                img1, img2, img1_warped, title="",
                ax1_title="Image 1", ax2_title="Image 2", ax3_title="Transformation: Image 1 to Image 2",
            )

        if save_warped:
            # Make sure the target directory exists before writing.
            out_dir = os.path.dirname(path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            plt.imsave(path, img1_warped)

        return best_params
if __name__ == "__main__":

    def _read_rgb(image_path):
        """Read an image as a 3-channel RGB array, failing loudly if unreadable."""
        image = cv2.imread(image_path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread reports an unreadable/missing file by returning None.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads channels as BGR; the matcher converts from RGB.
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # read & show images
    # Loaded with three channels because KeypointMatcher.match performs the
    # RGB -> grayscale conversion itself.
    boat1 = _read_rgb('boat1.pgm')
    boat2 = _read_rgb('boat2.pgm')
    show_two_images(boat1, boat2, title="Given pair of images.")

    # detect, describe and match keypoints, displaying a sample of matches
    kp_matcher = KeypointMatcher(contrastThreshold=0.1, edgeThreshold=5)
    matches, kp1, des1, kp2, des2 = kp_matcher.match(boat1, boat2, show_matches=True)