import sys
import io
import os
import cv2
import math
import numpy as np
from scipy.signal import medfilt
from scipy.spatial import KDTree
from matplotlib import pyplot as plt
from PIL import Image
from PIL.ImageOps import exif_transpose
import torchvision.transforms as tvf

from dust3r.inference import inference
from dust3r.utils.image import load_images  # , resize_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid
from third_party.utils.camera_utils import remap_points
from third_party.utils.img_utils import rgba_to_rgb, resize_with_aspect_ratio
from third_party.utils.img_utils import compute_img_diff

# normalize to [-1, 1] as expected by DUSt3R
ImgNorm = tvf.Compose([tvf.ToTensor(), tvf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
def suppress_output(func):
    """Decorator that silences stdout/stderr while `func` runs."""
    def wrapper(*args, **kwargs):
        original_stdout = sys.stdout
        original_stderr = sys.stderr
        sys.stdout = io.StringIO()
        sys.stderr = io.StringIO()
        try:
            return func(*args, **kwargs)
        finally:
            sys.stdout = original_stdout
            sys.stderr = original_stderr
    return wrapper
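
# Minimal usage sketch for suppress_output. The decorated function below is
# purely illustrative and not part of this module.
def _demo_suppress_output():
    @suppress_output
    def _noisy_add(a, b):
        print("this line is swallowed")  # captured by the StringIO buffer
        return a + b
    assert _noisy_add(1, 2) == 3  # return value passes through unchanged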
def _resize_pil_image(img, long_edge_size):
    S = max(img.size)
    # downscale with LANCZOS, upscale with BICUBIC
    interp = Image.LANCZOS if S > long_edge_size else Image.BICUBIC
    new_size = tuple(int(round(x * long_edge_size / S)) for x in img.size)
    return img.resize(new_size, interp)
def resize_images(imgs_list, size, square_ok=False):
    """Convert a list of numpy images to the input format expected by DUSt3R."""
    imgs = []
    for img in imgs_list:
        img = exif_transpose(Image.fromarray(img)).convert('RGB')
        W1, H1 = img.size
        if size == 224:
            # resize short side to 224 (then center-crop)
            img = _resize_pil_image(img, round(size * max(W1 / H1, H1 / W1)))
        else:
            # resize long side to 512
            img = _resize_pil_image(img, size)
        W, H = img.size
        cx, cy = W // 2, H // 2
        if size == 224:
            half = min(cx, cy)
            img = img.crop((cx - half, cy - half, cx + half, cy + half))
        else:
            # crop so both sides are multiples of 16
            halfw, halfh = ((2 * cx) // 16) * 8, ((2 * cy) // 16) * 8
            if not square_ok and W == H:
                halfh = 3 * halfw / 4
            img = img.crop((cx - halfw, cy - halfh, cx + halfw, cy + halfh))
        imgs.append(dict(img=ImgNorm(img)[None], true_shape=np.int32(
            [img.size[::-1]]), idx=len(imgs), instance=str(len(imgs))))
    return imgs
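
# Quick shape sanity check for resize_images (illustrative only; the random
# 64x48 frame stands in for a real RGB image).
def _demo_resize_images():
    frame = np.random.randint(0, 255, (48, 64, 3), dtype=np.uint8)
    packed = resize_images([frame], size=512, square_ok=True)
    print(packed[0]['img'].shape, packed[0]['true_shape'])  # (1, 3, 384, 512), [[384 512]]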
def infer_match(images, model, vis=False, niter=300, lr=0.01, schedule='cosine', device="cuda:0"):
    batch_size = 1
    images_packed = resize_images(images, size=512, square_ok=True)
    # images_packed = images
    pairs = make_pairs(images_packed, scene_graph='complete', prefilter=None, symmetrize=True)
    output = inference(pairs, model, device, batch_size=batch_size, verbose=False)
    scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PointCloudOptimizer)
    loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

    # retrieve useful values from the scene
    imgs = scene.imgs
    # focals = scene.get_focals()
    # poses = scene.get_im_poses()
    pts3d = scene.get_pts3d()
    confidence_masks = scene.get_masks()
    # visualize reconstruction
    # scene.show()

    # find 2D-2D matches between the two images
    pts2d_list, pts3d_list = [], []
    for i in range(2):
        conf_i = confidence_masks[i].cpu().numpy()
        pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])  # imgs[i].shape[:2] = (H, W)
        pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])
        if pts3d_list[-1].shape[0] == 0:
            return np.zeros((0, 2)), np.zeros((0, 2))
    reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
    matches_im1 = pts2d_list[1][reciprocal_in_P2]
    matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]

    # visualize a few matches
    if vis:
        print(f'found {num_matches} matches')
        n_viz = 20
        match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
        viz_matches_im0, viz_matches_im1 = matches_im0[match_idx_to_viz], matches_im1[match_idx_to_viz]
        H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
        img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
        img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), 'constant', constant_values=0)
        img = np.concatenate((img0, img1), axis=1)
        plt.figure()
        plt.imshow(img)
        cmap = plt.get_cmap('jet')
        for i in range(n_viz):
            (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
            plt.plot([x0, x1 + W0], [y0, y1], '-+', color=cmap(i / (n_viz - 1)), scalex=False, scaley=False)
        plt.show(block=True)

    # map matches back to the original (pre-resize) image coordinates
    matches_im0 = remap_points(images[0].shape, matches_im0)
    matches_im1 = remap_points(images[1].shape, matches_im1)
    return matches_im0, matches_im1
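
# Usage sketch for infer_match. The checkpoint name follows the public DUSt3R
# README and the file paths are hypothetical; substitute your own.
def _demo_infer_match():
    from dust3r.model import AsymmetricCroCo3DStereo
    model = AsymmetricCroCo3DStereo.from_pretrained(
        "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt").to("cuda:0")
    img_a = np.array(Image.open("frame_a.png").convert("RGB"))  # hypothetical paths
    img_b = np.array(Image.open("frame_b.png").convert("RGB"))
    matches_a, matches_b = infer_match([img_a, img_b], model)
    print(matches_a.shape, matches_b.shape)  # both (N, 2), in original pixel coords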
def point_transform(H, pt):
    """
    @param: H is homography matrix of dimension (3x3)
    @param: pt is the (x, y) point to be transformed
    Return:
        returns a transformed point ptrans = H*pt.
    """
    a = H[0, 0] * pt[0] + H[0, 1] * pt[1] + H[0, 2]
    b = H[1, 0] * pt[0] + H[1, 1] * pt[1] + H[1, 2]
    c = H[2, 0] * pt[0] + H[2, 1] * pt[1] + H[2, 2]
    return [a / c, b / c]
def points_transform(H, pt_x, pt_y):
    """
    Vectorized variant of point_transform.
    @param: H is homography matrix of dimension (3x3)
    @param: pt_x, pt_y are arrays of x and y coordinates to be transformed
    Return:
        returns the transformed coordinates (x', y') of H*(pt_x, pt_y).
    """
    a = H[0, 0] * pt_x + H[0, 1] * pt_y + H[0, 2]
    b = H[1, 0] * pt_x + H[1, 1] * pt_y + H[1, 2]
    c = H[2, 0] * pt_x + H[2, 1] * pt_y + H[2, 2]
    return (a / c, b / c)
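
# Sanity check for the two transforms above (illustrative only): the identity
# homography must leave points unchanged.
def _demo_point_transforms():
    H_id = np.eye(3)
    assert point_transform(H_id, [3.0, 4.0]) == [3.0, 4.0]
    xs, ys = points_transform(H_id, np.array([1.0, 2.0]), np.array([3.0, 4.0]))
    assert np.allclose(xs, [1.0, 2.0]) and np.allclose(ys, [3.0, 4.0])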
def motion_propagate(old_points, new_points, old_size, new_size, H_size=(21, 21)):
    """
    @param: old_points are points in the old frame that are
            matched feature points with the new frame
    @param: new_points are points in the new frame that are
            matched feature points with the old frame
    @param: old_size, new_size are the (H, W) sizes of the old and new frames
    @param: H_size is the number of mesh vertices along x and y
    Return:
        returns a motion mesh in x-direction
        and y-direction for the old frame
    """
    # spreads motion over the mesh for the old frame
    x_motion = np.zeros(H_size)
    y_motion = np.zeros(H_size)
    mesh_x_num, mesh_y_num = H_size[0], H_size[1]
    pixels_x, pixels_y = old_size[1] / (mesh_x_num - 1), old_size[0] / (mesh_y_num - 1)
    radius = max(pixels_x, pixels_y) * 5
    sigma = radius / 3.0
    H_global = None
    if old_points.shape[0] > 3:
        # pre-warping with global homography
        H_global, _ = cv2.findHomography(old_points, new_points, cv2.RANSAC)
    if H_global is None:
        # fall back to the corner-to-corner homography when there are too few
        # matches or RANSAC fails
        old_tmp = np.array([[0, 0], [0, old_size[0]], [old_size[1], 0], [old_size[1], old_size[0]]])
        new_tmp = np.array([[0, 0], [0, new_size[0]], [new_size[1], 0], [new_size[1], new_size[0]]])
        H_global, _ = cv2.findHomography(old_tmp, new_tmp, cv2.RANSAC)
    for i in range(mesh_x_num):
        for j in range(mesh_y_num):
            pt = [pixels_x * i, pixels_y * j]
            ptrans = point_transform(H_global, pt)
            x_motion[i, j] = ptrans[0]
            y_motion[i, j] = ptrans[1]

    # distribute feature motion vectors
    weighted_move_x = np.zeros(H_size)
    weighted_move_y = np.zeros(H_size)
    # build a KDTree over the old feature points
    tree = KDTree(old_points)
    # accumulate Gaussian-weighted residual motion around each mesh vertex
    for i in range(mesh_x_num):
        for j in range(mesh_y_num):
            vertex = [pixels_x * i, pixels_y * j]
            neighbor_indices = tree.query_ball_point(vertex, radius, workers=-1)
            if len(neighbor_indices) > 0:
                pts = old_points[neighbor_indices]
                sts = new_points[neighbor_indices]
                ptrans_x, ptrans_y = points_transform(H_global, pts[:, 0], pts[:, 1])
                moves_x = sts[:, 0] - ptrans_x
                moves_y = sts[:, 1] - ptrans_y
                dists = np.sqrt((vertex[0] - pts[:, 0]) ** 2 + (vertex[1] - pts[:, 1]) ** 2)
                weights = np.exp(-(dists ** 2) / (2 * sigma ** 2))
                weighted_move_x[i, j] = np.sum(weights * moves_x) / (np.sum(weights) + 0.1)
                weighted_move_y[i, j] = np.sum(weights * moves_y) / (np.sum(weights) + 0.1)
    x_motion_mesh = x_motion + weighted_move_x
    y_motion_mesh = y_motion + weighted_move_y
    # optional: a median filter over the motion meshes (e.g. scipy.signal.medfilt
    # with kernel_size=[3, 3]) could suppress outlier vertices
    return x_motion_mesh, y_motion_mesh
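
# Shape sketch for motion_propagate (the random matches are illustrative only;
# real inputs come from infer_match). Note the returned meshes hold absolute
# target coordinates per vertex, not displacements.
def _demo_motion_propagate():
    rng = np.random.default_rng(0)
    old_pts = rng.uniform(0, 480, size=(100, 2))
    new_pts = old_pts + rng.normal(0, 2.0, size=(100, 2))  # small jitter
    mesh_x, mesh_y = motion_propagate(old_pts, new_pts, (480, 640), (480, 640))
    print(mesh_x.shape, mesh_y.shape)  # (21, 21), (21, 21)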
def mesh_warp_points(points, x_motion_mesh, y_motion_mesh, img_size):
    """Warp 2D points through the motion mesh via each cell's local homography."""
    ptrans = []
    mesh_x_num, mesh_y_num = x_motion_mesh.shape
    pixels_x, pixels_y = img_size[1] / (mesh_x_num - 1), img_size[0] / (mesh_y_num - 1)
    for pt in points:
        # clamp to the last cell so points on the right/bottom border stay in range
        i = min(int(pt[0] // pixels_x), mesh_x_num - 2)
        j = min(int(pt[1] // pixels_y), mesh_y_num - 2)
        src = [[i * pixels_x, j * pixels_y],
               [(i + 1) * pixels_x, j * pixels_y],
               [i * pixels_x, (j + 1) * pixels_y],
               [(i + 1) * pixels_x, (j + 1) * pixels_y]]
        src = np.asarray(src)
        dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
               [x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
               [x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
               [x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
        dst = np.asarray(dst)
        H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
        x, y = points_transform(H, pt[0], pt[1])
        ptrans.append([x, y])
    return np.array(ptrans)
def mesh_warp_frame(frame, x_motion_mesh, y_motion_mesh, resize):
    """
    @param: frame is the current frame
    @param: x_motion_mesh is the motion mesh to
            be warped on frame along x-direction
    @param: y_motion_mesh is the motion mesh to
            be warped on frame along y-direction
    @param: resize is the desired output size as a (height, width) tuple
    Returns:
        returns a mesh warped frame according
        to the given motion meshes x_motion_mesh,
        y_motion_mesh, resized to the specified size
    """
    map_x = np.zeros(resize, np.float32)
    map_y = np.zeros(resize, np.float32)
    mesh_x_num, mesh_y_num = x_motion_mesh.shape
    pixels_x, pixels_y = resize[1] / (mesh_x_num - 1), resize[0] / (mesh_y_num - 1)
    for i in range(mesh_x_num - 1):
        for j in range(mesh_y_num - 1):
            src = [[i * pixels_x, j * pixels_y],
                   [(i + 1) * pixels_x, j * pixels_y],
                   [i * pixels_x, (j + 1) * pixels_y],
                   [(i + 1) * pixels_x, (j + 1) * pixels_y]]
            src = np.asarray(src)
            dst = [[x_motion_mesh[i, j], y_motion_mesh[i, j]],
                   [x_motion_mesh[i + 1, j], y_motion_mesh[i + 1, j]],
                   [x_motion_mesh[i, j + 1], y_motion_mesh[i, j + 1]],
                   [x_motion_mesh[i + 1, j + 1], y_motion_mesh[i + 1, j + 1]]]
            dst = np.asarray(dst)
            H, _ = cv2.findHomography(src, dst, cv2.RANSAC)
            start_x = math.ceil(pixels_x * i)
            end_x = math.ceil(pixels_x * (i + 1))
            start_y = math.ceil(pixels_y * j)
            end_y = math.ceil(pixels_y * (j + 1))
            x, y = np.meshgrid(range(start_x, end_x), range(start_y, end_y), indexing='ij')
            map_x[y, x], map_y[y, x] = points_transform(H, x, y)
    # deform the mesh and directly output the resized frame
    resized_frame = cv2.remap(frame, map_x, map_y,
                              interpolation=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT,
                              borderValue=(255, 255, 255))
    return resized_frame
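
# Identity-warp sketch (illustrative only): a motion mesh whose vertices hold
# the undisplaced grid coordinates should reproduce the input frame.
def _demo_mesh_warp_identity():
    h, w = 240, 320
    frame = np.random.randint(0, 255, (h, w, 3), dtype=np.uint8)
    ii, jj = np.meshgrid(np.arange(21), np.arange(21), indexing='ij')
    mesh_x = ii * (w / 20.0)  # vertex x-coordinates, unmoved
    mesh_y = jj * (h / 20.0)  # vertex y-coordinates, unmoved
    warped = mesh_warp_frame(frame, mesh_x, mesh_y, (h, w))
    print(np.abs(warped.astype(int) - frame.astype(int)).mean())  # ~0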
def infer_warp_mesh_img(src, dst, model, vis=False):
    if isinstance(src, str):
        image1 = cv2.imread(src, cv2.IMREAD_UNCHANGED)
        image2 = cv2.imread(dst, cv2.IMREAD_UNCHANGED)
        image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
        image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
    elif isinstance(src, Image.Image):
        image1 = np.array(src)
        image2 = np.array(dst)
    else:
        assert isinstance(src, np.ndarray)
        image1 = rgba_to_rgb(src)
        image2 = rgba_to_rgb(dst)
    image1_padded = resize_with_aspect_ratio(image1, image2)
    resized_image1 = cv2.resize(image1_padded, (image2.shape[1], image2.shape[0]), interpolation=cv2.INTER_AREA)
    matches_im0, matches_im1 = infer_match([resized_image1, image2], model, vis=vis)
    # map matches from the resized frame back to the padded frame's scale
    matches_im0 = matches_im0 * image1_padded.shape[0] / resized_image1.shape[0]
    # estimate the motion mesh that carries new-frame matches onto old-frame matches
    mesh_x, mesh_y = motion_propagate(matches_im1, matches_im0, image2.shape[:2], image1_padded.shape[:2])
    aligned_image = mesh_warp_frame(image1_padded, mesh_x, mesh_y, (image2.shape[0], image2.shape[1]))
    matches_im0_from_im1 = mesh_warp_points(matches_im1, mesh_x, mesh_y, (image2.shape[1], image2.shape[0]))
    info = compute_img_diff(aligned_image, image2, matches_im0, matches_im0_from_im1, vis=vis)
    return aligned_image, info
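
# End-to-end usage sketch (hypothetical file names; the model is loaded as in
# _demo_infer_match above).
def _demo_infer_warp_mesh_img(model):
    aligned, info = infer_warp_mesh_img("render.png", "photo.png", model, vis=False)
    Image.fromarray(aligned).save("aligned.png")
    print(info)  # alignment diff statistics reported by compute_img_diff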