vggsfm / vggsfm_code /hf_demo.py
JianyuanWang's picture
close center order back
7c6db12
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import time
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.cuda.amp import autocast
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from lightglue import LightGlue, SuperPoint, SIFT, ALIKED
import pycolmap
# from visdom import Visdom
from vggsfm.datasets.demo_loader import DemoLoader
from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras
try:
    # poselib is optional: it is only needed when cfg.use_poselib selects the
    # poselib-based preliminary camera estimator.
    import poselib  # noqa: F401
    from vggsfm.two_view_geo.estimate_preliminary import estimate_preliminary_cameras_poselib

    print("Poselib is available")
except ImportError:
    # Narrow catch: a bare except would also hide KeyboardInterrupt/SystemExit
    # and unrelated errors raised while importing.
    print("Poselib is not installed. Please disable use_poselib")
from vggsfm.utils.utils import (
set_seed_and_print,
farthest_point_sampling,
calculate_index_mappings,
switch_tensor_order,
)
def demo_fn(cfg):
    """Run the VGGSfM demo reconstruction on every scene found by ``DemoLoader``.

    Instantiates the model from ``cfg.MODEL``, loads the VGGSfM v2.0.0 weights
    from the Hugging Face hub and calls :func:`run_one_scene` on each sequence.

    Args:
        cfg: OmegaConf config. Fields read here: ``seed``, ``MODEL``,
            ``SCENE_DIR``, ``img_size``, ``load_gt``, ``mixed_precision``
            (``"None"``, ``"bf16"`` or ``"fp16"``) and ``query_frame_num``; the
            whole config is also forwarded to the model, the loader and
            :func:`run_one_scene`.

    Returns:
        The predictions dict of the last processed sequence.

    Raises:
        NotImplementedError: if ``cfg.mixed_precision`` is not supported.
        ValueError: if the loader exposes no sequences.
    """
    OmegaConf.set_struct(cfg, False)

    # Print configuration
    print("Model Config:", OmegaConf.to_yaml(cfg))

    # Resolve the autocast dtype once, before any expensive work. An explicit
    # raise (unlike an assert) is not stripped under `python -O`.
    dtype_by_precision = {
        "None": torch.float32,
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
    }
    if cfg.mixed_precision not in dtype_by_precision:
        raise NotImplementedError(f"dtype {cfg.mixed_precision} is not supported now")
    dtype = dtype_by_precision[cfg.mixed_precision]

    # NOTE(review): cuDNN is disabled, so the two flags below only take effect
    # if it is re-enabled.
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

    # Set seed
    seed_all_random_engines(cfg.seed)

    # Model instantiation
    model = instantiate(cfg.MODEL, _recursive_=False, cfg=cfg)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Prepare test dataset
    test_dataset = DemoLoader(
        SCENE_DIR=cfg.SCENE_DIR, img_size=cfg.img_size, normalize_cameras=False, load_gt=cfg.load_gt, cfg=cfg
    )
    sequence_list = test_dataset.sequence_list
    # Without any sequence the loop below never runs and there is nothing to return.
    if len(sequence_list) == 0:
        raise ValueError(f"No sequence found under SCENE_DIR={cfg.SCENE_DIR}")

    # The demo always loads the released VGGSfM v2.0.0 weights.
    _VGGSFM_URL = "https://huggingface.co/facebook/VGGSfM/resolve/main/vggsfm_v2_0_0.bin"
    checkpoint = torch.hub.load_state_dict_from_url(_VGGSFM_URL)
    model.load_state_dict(checkpoint, strict=True)
    print(f"Successfully resumed from {_VGGSFM_URL}")

    for seq_name in sequence_list:
        print("*" * 50 + f" Testing on Scene {seq_name} " + "*" * 50)

        # Load the data
        batch, image_paths = test_dataset.get_data(sequence_name=seq_name, return_path=True)

        # Move to the device and add a leading batch dimension of size 1
        images = batch["image"].to(device).unsqueeze(0)
        crop_params = batch["crop_params"].to(device).unsqueeze(0)

        with torch.no_grad():
            predictions = run_one_scene(
                model,
                images,
                crop_params=crop_params,
                query_frame_num=cfg.query_frame_num,
                image_paths=image_paths,
                dtype=dtype,
                cfg=cfg,
            )

    return predictions
def run_one_scene(model, images, crop_params=None, query_frame_num=3, image_paths=None, dtype=None, cfg=None):
    """Reconstruct one scene: track 2D points, estimate cameras, triangulate + BA.

    images have been normalized to the range [0, 1] instead of [0, 255]

    Args:
        model: VGGSfM model exposing ``camera_predictor``, ``track_predictor``
            and ``triangulator``.
        images: tensor unpacked as (batch_num, frame_num, image_dim, height, width).
        crop_params: per-frame crop parameters indexed as [batch, frame, ...];
            entries ``[..., -4:-2]`` are used as padding boundaries.
        query_frame_num: how many query frames seed the tracker.
        image_paths: one path per frame.
        dtype: autocast dtype used for query-frame search and track prediction.
        cfg: config providing the flags/thresholds read below.

    Returns:
        dict with keys "pred_cameras_PT3D", "extrinsics_opencv",
        "intrinsics_opencv", "points3D", "points3D_rgb", "reconstruction",
        "images" and "raw_image_paths". When ``cfg.center_order`` is set, every
        per-frame output (including "raw_image_paths") follows the re-centred
        frame order.
    """
    batch_num, frame_num, image_dim, height, width = images.shape
    device = images.device
    reshaped_image = images.reshape(batch_num * frame_num, image_dim, height, width)

    camera_predictor = model.camera_predictor
    track_predictor = model.track_predictor
    triangulator = model.triangulator

    # Find the query frames
    # First use DINO to find the most common frame among all the input frames
    # i.e., the one has highest (average) cosine similarity to all others
    # Then use farthest_point_sampling to find the next ones
    # The number of query frames is determined by query_frame_num
    with autocast(dtype=dtype):
        query_frame_indexes = find_query_frame_indexes(reshaped_image, camera_predictor, frame_num)

    raw_image_paths = image_paths
    image_paths = [os.path.basename(imgpath) for imgpath in image_paths]

    if cfg.center_order:
        # Swap frame 0 with the most common frame (the first query frame)
        center_frame_index = query_frame_indexes[0]
        center_order = calculate_index_mappings(center_frame_index, frame_num, device=device)

        images, crop_params = switch_tensor_order([images, crop_params], center_order, dim=1)
        reshaped_image = switch_tensor_order([reshaped_image], center_order, dim=0)[0]

        center_order_list = center_order.cpu().numpy().tolist()
        image_paths = [image_paths[i] for i in center_order_list]
        # The outputs are NOT switched back to the input order, so the full
        # paths must follow the same re-ordering; otherwise "raw_image_paths"
        # (and its valid_frame_mask filtering below) would pair each path with
        # another frame's camera.
        raw_image_paths = [raw_image_paths[i] for i in center_order_list]

        # Also update query_frame_indexes: frames 0 and center_frame_index swapped
        query_frame_indexes = [center_frame_index if x == 0 else x for x in query_frame_indexes]
        query_frame_indexes[0] = 0

    # only pick query_frame_num
    query_frame_indexes = query_frame_indexes[:query_frame_num]

    # Prepare image feature maps for tracker
    fmaps_for_tracker = track_predictor.process_images_to_fmaps(images)

    # Predict tracks
    with autocast(dtype=dtype):
        pred_track, pred_vis, pred_score = predict_tracks(
            cfg.query_method,
            cfg.max_query_pts,
            track_predictor,
            images,
            fmaps_for_tracker,
            query_frame_indexes,
            frame_num,
            device,
            cfg,
        )

        if cfg.comple_nonvis:
            # Add queries on frames with fewer than 200 visible tracks
            pred_track, pred_vis, pred_score = comple_nonvis_frames(
                track_predictor,
                images,
                fmaps_for_tracker,
                frame_num,
                device,
                pred_track,
                pred_vis,
                pred_score,
                200,
                cfg=cfg,
            )

    torch.cuda.empty_cache()

    # If necessary, force all the predictions at the padding areas as non-visible
    if crop_params is not None:
        boundaries = crop_params[:, :, -4:-2].abs().to(device)
        # NOTE(review): the far boundary uses the width for both axes, which
        # assumes square inputs — confirm.
        boundaries = torch.cat([boundaries, reshaped_image.shape[-1] - boundaries], dim=-1)
        hvis = torch.logical_and(
            pred_track[..., 1] >= boundaries[:, :, 1:2], pred_track[..., 1] <= boundaries[:, :, 3:4]
        )
        wvis = torch.logical_and(
            pred_track[..., 0] >= boundaries[:, :, 0:1], pred_track[..., 0] <= boundaries[:, :, 2:3]
        )
        force_vis = torch.logical_and(hvis, wvis)
        pred_vis = pred_vis * force_vis.float()

    # TODO: plot 2D matches
    if cfg.use_poselib:
        estimate_preliminary_cameras_fn = estimate_preliminary_cameras_poselib
    else:
        estimate_preliminary_cameras_fn = estimate_preliminary_cameras

    # Estimate preliminary_cameras by recovering fundamental/essential/homography matrix from 2D matches
    # By default, we use fundamental matrix estimation with 7p/8p+LORANSAC
    # All the operations are batched and differentiable (if necessary)
    # except when you enable use_poselib to save GPU memory
    _, preliminary_dict = estimate_preliminary_cameras_fn(
        pred_track,
        pred_vis,
        width,
        height,
        tracks_score=pred_score,
        max_error=cfg.fmat_thres,
        loopresidual=True,
        # max_ransac_iters=cfg.max_ransac_iters,
    )

    pose_predictions = camera_predictor(reshaped_image, batch_size=batch_num)
    pred_cameras = pose_predictions["pred_cameras"]

    # Conduct Triangulation and Bundle Adjustment
    (
        BA_cameras_PT3D,
        extrinsics_opencv,
        intrinsics_opencv,
        points3D,
        points3D_rgb,
        reconstruction,
        valid_frame_mask,
    ) = triangulator(
        pred_cameras,
        pred_track,
        pred_vis,
        images,
        preliminary_dict,
        image_paths=image_paths,
        crop_params=crop_params,
        pred_score=pred_score,
        fmat_thres=cfg.fmat_thres,
        BA_iters=cfg.BA_iters,
        max_reproj_error=cfg.max_reproj_error,
        init_max_reproj_error=cfg.init_max_reproj_error,
        cfg=cfg,
    )

    # NOTE: with cfg.center_order the outputs stay in the re-centred order
    # (switching them back is disabled); raw_image_paths was re-ordered to match.
    if cfg.filter_invalid_frame:
        raw_image_paths = np.array(raw_image_paths)[valid_frame_mask.cpu().numpy().tolist()].tolist()
        images = images[0][valid_frame_mask]

    return {
        "pred_cameras_PT3D": BA_cameras_PT3D,
        "extrinsics_opencv": extrinsics_opencv,
        "intrinsics_opencv": intrinsics_opencv,
        "points3D": points3D,
        "points3D_rgb": points3D_rgb,
        "reconstruction": reconstruction,
        "images": images,
        "raw_image_paths": raw_image_paths,
    }
def predict_tracks(
    query_method,
    max_query_pts,
    track_predictor,
    images,
    fmaps_for_tracker,
    query_frame_indexes,
    frame_num,
    device,
    cfg=None,
):
    """Track keypoints seeded from each query frame and concatenate the results.

    For every index in ``query_frame_indexes``, keypoints are detected on that
    frame with :func:`get_query_points` and tracked across all frames. The frame
    order is temporarily switched (``calculate_index_mappings``) so the query
    frame comes first for the tracker, then the predictions are switched back.

    Args:
        query_method: detector spec forwarded to :func:`get_query_points`.
        max_query_pts: keypoint budget forwarded to :func:`get_query_points`.
        track_predictor: tracker called as ``(images, query_points, fmaps=...)``.
        images: frames indexed as ``images[:, frame]``.
        fmaps_for_tracker: feature maps re-ordered together with ``images``.
        query_frame_indexes: iterable of query frame indexes.
        frame_num: total number of frames.
        device: device for the index mapping.
        cfg: not used by this function.

    Returns:
        (pred_track, pred_vis, pred_score): per-query predictions concatenated
        along dim 2.
    """
    tracks_all, vis_all, scores_all = [], [], []

    for qi in query_frame_indexes:
        print(f"Predicting tracks with query_index = {qi}")

        seed_points = get_query_points(images[:, qi], query_method, max_query_pts)

        # The tracker expects the query frame to be the first frame.
        mapping = calculate_index_mappings(qi, frame_num, device=device)
        imgs_q, fmaps_q = switch_tensor_order([images, fmaps_for_tracker], mapping)

        tracks_q, _, vis_q, score_q = track_predictor(imgs_q, seed_points, fmaps=fmaps_q)

        # Switch the predictions back to the caller's frame order.
        tracks_q, vis_q, score_q = switch_tensor_order([tracks_q, vis_q, score_q], mapping)

        tracks_all.append(tracks_q)
        vis_all.append(vis_q)
        scores_all.append(score_q)

    return torch.cat(tracks_all, dim=2), torch.cat(vis_all, dim=2), torch.cat(scores_all, dim=2)
def comple_nonvis_frames(
    track_predictor,
    images,
    fmaps_for_tracker,
    frame_num,
    device,
    pred_track,
    pred_vis,
    pred_score,
    min_vis=500,
    cfg=None,
):
    """Add tracks for frames that have too few visible points.

    A frame is under-observed when fewer than ``min_vis`` of its tracks have
    visibility > 0.05 (``pred_vis`` is squeezed on dim 0, so a batch of one is
    assumed). Under-observed frames are used as extra query frames one at a
    time with ``cfg.query_method`` / ``cfg.max_query_pts``. If the same frame is
    still under-observed right after being queried, a last pass queries every
    remaining under-observed frame with "sp+sift+aliked" and half of
    ``cfg.max_query_pts``, and the loop stops.

    Returns:
        (pred_track, pred_vis, pred_score) with the new tracks concatenated
        along dim 2; the inputs are returned unchanged when no frame is
        under-observed.
    """

    def _under_observed(vis):
        # Indexes of frames whose count of visible (> 0.05) tracks is below min_vis.
        visible_counts = (vis.squeeze(0) > 0.05).sum(-1)
        return torch.nonzero(visible_counts < min_vis).squeeze(-1).tolist()

    def _merge(current, extra):
        # Concatenate the new tracks/visibility/scores onto the existing ones.
        return tuple(torch.cat([old, new], dim=2) for old, new in zip(current, extra))

    pending = _under_observed(pred_vis)
    previous_query = -1

    while pending:
        print("Processing non visible frames")
        print(pending)

        if pending[0] == previous_query:
            # Querying this frame alone was not enough: try all detectors on every
            # remaining under-observed frame, then stop.
            print("The non vis frame still does not has enough 2D matches")
            extra = predict_tracks(
                "sp+sift+aliked",
                cfg.max_query_pts // 2,
                track_predictor,
                images,
                fmaps_for_tracker,
                pending,
                frame_num,
                device,
                cfg,
            )
            pred_track, pred_vis, pred_score = _merge((pred_track, pred_vis, pred_score), extra)
            break

        previous_query = pending[0]
        extra = predict_tracks(
            cfg.query_method,
            cfg.max_query_pts,
            track_predictor,
            images,
            fmaps_for_tracker,
            [previous_query],
            frame_num,
            device,
            cfg,
        )
        pred_track, pred_vis, pred_score = _merge((pred_track, pred_vis, pred_score), extra)
        pending = _under_observed(pred_vis)

    return pred_track, pred_vis, pred_score
def find_query_frame_indexes(reshaped_image, camera_predictor, query_frame_num, image_size=336):
    """Pick query frames: the most "common" frame first, then FPS-sampled ones.

    Frames are resized to ``image_size`` x ``image_size``, normalized with the
    camera predictor's ``_resnet_normalize_image`` and encoded by its backbone.
    Per-patch frame-to-frame similarities are averaged over patches. The frame
    with the largest summed similarity to the other frames seeds farthest point
    sampling on the distance ``100 - similarity``.

    Args:
        reshaped_image: frames stacked along dim 0, suitable for 2-D ``F.interpolate``.
        camera_predictor: provides ``_resnet_normalize_image`` and a ``backbone``
            whose output contains ``"x_norm_patchtokens"``.
        query_frame_num: number of frames to sample.
        image_size: side length used for feature extraction.

    Returns:
        The result of ``farthest_point_sampling`` for ``query_frame_num`` samples.
    """
    # A low resolution is enough for frame-level similarity.
    small = F.interpolate(reshaped_image, (image_size, image_size), mode="bilinear", align_corners=True)
    small = camera_predictor._resnet_normalize_image(small)

    tokens = camera_predictor.backbone(small, is_training=True)["x_norm_patchtokens"]
    # NOTE(review): normalization is along dim=1, not the last dim; kept as-is
    # because it determines which frames get selected.
    tokens = F.normalize(tokens, p=2, dim=1).permute(1, 0, 2)

    # Per-patch similarity between frames, averaged over patches.
    sim = torch.bmm(tokens, tokens.transpose(-1, -2)).mean(dim=0)
    dist = 100 - sim  # new tensor, so the in-place edit below does not touch it

    # Exclude self-similarity before summing.
    sim.fill_diagonal_(-100)
    anchor = torch.argmax(sim.sum(dim=1)).item()

    # FPS from the anchor frame: repeatedly take the frame farthest from those
    # already chosen (no frame is chosen twice).
    return farthest_point_sampling(dist, query_frame_num, anchor)
def get_query_points(query_image, query_method, max_query_num=4096, det_thres=0.005):
    """Detect keypoints on a query frame to seed the tracker.

    Args:
        query_image: image tensor passed to each extractor's ``extract``.
        query_method: "+"-separated detector list, e.g. "sp+sift+aliked". A token
            containing "sp", "sift" or "aliked" selects SuperPoint, SIFT or
            ALIKED respectively.
        max_query_num: keypoint cap per detector, and for the concatenated result.
        det_thres: detection threshold for SuperPoint and ALIKED.

    Returns:
        Keypoints from all detectors concatenated along dim 1 (a 3-D tensor,
        presumably (batch, num_points, 2) — confirm against lightglue). When more
        than ``max_query_num`` points are found, a random subset of that size is
        kept.

    Raises:
        NotImplementedError: for an unrecognised detector token.
    """
    # Build extractors on the image's device instead of hard-coding CUDA, so the
    # CPU fallback selected by demo_fn works too.
    device = query_image.device

    pred_points = []
    for method in query_method.split("+"):
        if "sp" in method:
            extractor = SuperPoint(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        elif "sift" in method:
            extractor = SIFT(max_num_keypoints=max_query_num)
        elif "aliked" in method:
            extractor = ALIKED(max_num_keypoints=max_query_num, detection_threshold=det_thres)
        else:
            raise NotImplementedError(f"query method {method} is not supprted now")

        extractor = extractor.to(device).eval()
        pred_points.append(extractor.extract(query_image)["keypoints"])

    query_points = torch.cat(pred_points, dim=1)

    # Randomly subsample when the detectors together exceed the budget.
    if query_points.shape[1] > max_query_num:
        random_point_indices = torch.randperm(query_points.shape[1])[:max_query_num]
        query_points = query_points[:, random_point_indices, :]

    return query_points
def seed_all_random_engines(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with one value."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)