Spaces:
Running
Running
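'''
COTR demo for homography estimation: the four corners of a painting are
hand-annotated in one photo, located in a second photo with COTR, and a
replacement image is warped onto the painting in the second photo.
'''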
import argparse
import os
import sys
import time

import cv2
import numpy as np
import torch
import imageio
import matplotlib.pyplot as plt

from COTR.utils import utils, debug_utils
from COTR.models import build_model
from COTR.options.options import *
from COTR.options.options_utils import *
from COTR.inference.inference_helper import triangulate_corr
from COTR.inference.sparse_engine import SparseEngine
# Global setup for this inference-only demo.
utils.fix_randomness(0)  # presumably seeds COTR's random sources with 0 -- see COTR.utils.utils
torch.set_grad_enabled(mode=False)  # no training here, so autograd stays off for the whole run
def _load_model(opt):
    """Build the COTR model on the GPU, load pretrained weights and switch to eval mode.

    Args:
        opt: parsed options; must carry ``load_weights_path`` (set by the
            command-line entry point from ``--out_dir`` and ``--load_weights``).

    Returns:
        The model in eval mode, on the CUDA device.

    Raises:
        ValueError: if ``opt`` has no ``load_weights_path`` (no weights given).
    """
    weights_path = getattr(opt, 'load_weights_path', None)
    if weights_path is None:
        # Fail before building the model / touching the GPU.
        raise ValueError('no pretrained weights given: run with --load_weights <model id> '
                         'so that opt.load_weights_path is set')
    model = build_model(opt).cuda()
    # Load on CPU first; safe_load_weights copies the tensors into the CUDA model.
    weights = torch.load(weights_path, map_location='cpu')['model_state_dict']
    utils.safe_load_weights(model, weights)
    return model.eval()


def _paste_by_homography(src_img, dst_corners, dst_img):
    """Warp ``src_img`` so its image corners land on ``dst_corners`` and paste it into ``dst_img``.

    Args:
        src_img: image to paste, (H, W, C) array.
        dst_corners: four (x, y) points in ``dst_img``, ordered top-left,
            top-right, bottom-left, bottom-right.
        dst_img: target image, (H', W', C) array with the same dtype as ``src_img``.

    Returns:
        Copy of ``dst_img`` whose pixels covered by the warped ``src_img`` are
        replaced by it.
    """
    h, w = src_img.shape[:2]
    # Same corner order as dst_corners: TL, TR, BL, BR, in (x, y).
    src_corners = np.array([[0, 0], [w, 0], [0, h], [w, h]], dtype=np.float32)
    # np.array copies, giving OpenCV a contiguous float32 (4, 2) array even for a column slice.
    dst_corners = np.array(dst_corners, dtype=np.float32)
    T = cv2.getPerspectiveTransform(src_corners, dst_corners)
    dsize = (dst_img.shape[1], dst_img.shape[0])  # OpenCV expects (width, height)
    # Warp an all-ones mask with the same transform to find the covered pixels.
    vmask = cv2.warpPerspective(np.ones((h, w)), T, dsize) > 0
    warped = cv2.warpPerspective(src_img, T, dsize)
    return np.where(vmask[..., None], warped, dst_img)


def _show_panels(panels):
    """Show ``(title, image)`` pairs side by side in one row with axes hidden."""
    # squeeze=False keeps `axes` 2-D, so this also works for a single panel.
    _, axes = plt.subplots(1, len(panels), squeeze=False)
    for ax, (title, img) in zip(axes[0], panels):
        ax.imshow(img)
        ax.title.set_text(title)
        ax.axis('off')
    plt.show()


def main(opt):
    """Run the homography demo on the bundled sample images and display the result.

    Loads the model, then reads two photos and a replacement image from
    ``./sample_data/imgs`` (relative to the working directory). It finds
    the four hand-annotated corners of ``paint_1.JPG`` in ``paint_2.jpg``
    with COTR and warps the replacement image onto them. It then opens a
    matplotlib window with the inputs and the overlay.

    Args:
        opt: parsed COTR options including ``load_weights_path``. A CUDA
            device is required because the model is moved with ``.cuda()``.

    Raises:
        ValueError: if ``opt`` has no ``load_weights_path``.
    """
    model = _load_model(opt)
    img_a = imageio.imread('./sample_data/imgs/paint_1.JPG', pilmode='RGB')
    img_b = imageio.imread('./sample_data/imgs/paint_2.jpg', pilmode='RGB')
    rep_img = imageio.imread('./sample_data/imgs/Meisje_met_de_parel.jpg', pilmode='RGB')

    # Hand-annotated (x, y) corners of the painting in img_a.
    lu_corner = [932, 1025]   # left-upper
    ru_corner = [2469, 901]   # right-upper
    lb_corner = [908, 2927]   # left-bottom
    rb_corner = [2436, 3080]  # right-bottom
    queries = np.array([lu_corner, ru_corner, lb_corner, rb_corner], dtype=np.float32)

    engine = SparseEngine(model, 32, mode='stretching')
    corrs = engine.cotr_corr_multiscale(img_a, img_b, np.linspace(0.5, 0.0625, 4), 1,
                                        queries_a=queries, force=True)
    # Columns 2: are used as the (x, y) matches of the queries in img_b; assumes
    # rows come back in query order -- NOTE(review): confirm in SparseEngine.
    out = _paste_by_homography(rep_img, corrs[:, 2:], img_b)

    _show_panels([
        ('Virtual Paint', rep_img),
        ('Annotated Frame', img_a),
        ('Target Frame', img_b),
        ('Overlay', out),
    ])
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    set_COTR_arguments(parser)
    parser.add_argument('--out_dir', type=str, default=general_config['out'], help='out directory')
    parser.add_argument('--load_weights', type=str, default=None, help='load a pretrained set of weights, you need to provide the model id')
    opt = parser.parse_args()
    opt.command = ' '.join(sys.argv)

    # dim_feedforward is tied to the channel count of the selected `layer`
    # (the widths match ResNet-50's four stages).
    layer_2_channels = {'layer1': 256,
                        'layer2': 512,
                        'layer3': 1024,
                        'layer4': 2048, }
    if opt.layer not in layer_2_channels:
        parser.error('unsupported layer {!r}; expected one of {}'.format(
            opt.layer, sorted(layer_2_channels)))
    opt.dim_feedforward = layer_2_channels[opt.layer]

    # main() cannot run without a checkpoint, so reject a missing id up front
    # instead of crashing later on the unset opt.load_weights_path.
    if not opt.load_weights:
        parser.error('--load_weights is required: pass the model id of a checkpoint stored under --out_dir')
    opt.load_weights_path = os.path.join(opt.out_dir, opt.load_weights, 'checkpoint.pth.tar')
    print_opt(opt)
    main(opt)