from itertools import product

import gradio as gr
import matplotlib
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from matplotlib import pyplot as plt

from common.evaluation import Evaluator
from model import chmnet
from model.base.geometry import Geometry
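
# Fetch the pretrained CHM weights (pas_psi.pt) from Google Drive.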
torchvision.datasets.utils.download_file_from_google_drive(
    '1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6', '.', 'pas_psi.pt')
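
# Demo configuration: PCK thresholds, benchmark, input resolution, and checkpoint.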
args = {
    'alpha': [0.05, 0.1],
    'benchmark': 'pfpascal',
    'bsz': 90,
    'datapath': '../Datasets_CHM',
    'img_size': 240,
    'ktype': 'psi',
    'load': 'pas_psi.pt',
    'thres': 'img',
}
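
# Build CHM, load the checkpoint on CPU, initialise the evaluator and geometry
# helpers, and switch the model to inference mode.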
model = chmnet.CHMNet(args['ktype'])
model.load_state_dict(torch.load(args['load'], map_location=torch.device('cpu')))
Evaluator.initialize(args['alpha'])
Geometry.initialize(img_size=args['img_size'])
model.eval()
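
# Input transform for the network (resize, square crop, ImageNet normalisation)
# and a crop-only counterpart used when plotting the images.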
chm_transform = transforms.Compose([
    transforms.Resize(args['img_size']),
    transforms.CenterCrop((args['img_size'], args['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

chm_transform_plot = transforms.Compose([
    transforms.Resize(args['img_size']),
    transforms.CenterCrop((args['img_size'], args['img_size'])),
])
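
# One colour per grid keypoint, sampled evenly from the 'Spectral' colormap (7x7 = 49 points).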
cmap = matplotlib.cm.get_cmap('Spectral')
colors = [cmap(k / 49.0) for k in range(49)]
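

# Run CHM on one image pair: build the correspondence volume, transfer the
# source keypoints onto the target image, and draw both point sets side by side.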
def run_chm(source_image, target_image, selected_points, number_src_points, chm_transform, display_transform):
    # Normalised (1, 3, H, W) tensors for the network.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    # selected_points is a (2, n) array (row 0: x, row 1: y); add a batch dim
    # and record the number of valid keypoints.
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # Forward pass: build the correspondence volume, then transfer the source keypoints.
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    # Drop the batch dimension; both are (2, n) arrays of x- and y-coordinates.
    src_points = keypoints[0].numpy()
    tgt_points = prd_kps[0].cpu().numpy()

    # Map source keypoints from the model's img_size frame onto the display image.
    src_points_converted = []
    w, h = display_transform(source_image).size

    for x, y in zip(src_points[0], src_points[1]):
        src_points_converted.append([int(x * w / args['img_size']), int(y * h / args['img_size'])])

    src_points_converted = np.asarray(src_points_converted[:number_src_points])

    # Predicted keypoints are in normalised [-1, 1] coordinates; map them to display pixels.
    tgt_points_converted = []
    w, h = display_transform(target_image).size

    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_points_converted.append([int((x + 1) / 2.0 * w), int((y + 1) / 2.0 * h)])

    tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])

    # Side-by-side plot: source keypoints (left) and their predicted matches (right).
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    ax[0].imshow(display_transform(source_image))
    ax[0].scatter(src_points_converted[:, 0], src_points_converted[:, 1], c=colors[:number_src_points])
    ax[0].set_title('Source')
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    ax[1].imshow(display_transform(target_image))
    ax[1].scatter(tgt_points_converted[:, 0], tgt_points_converted[:, 1], c=colors[:number_src_points])
    ax[1].set_title('Target')
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    # Number every keypoint so individual correspondences can be read off.
    for i in range(number_src_points):
        ax[0].text(x=src_points_converted[i][0], y=src_points_converted[i][1], s=str(i), fontdict=dict(color='red', size=11))
        ax[1].text(x=tgt_points_converted[i][0], y=tgt_points_converted[i][1], s=str(i), fontdict=dict(color='orange', size=11))

    fig.suptitle('CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights', fontsize=16)
    plt.tight_layout()
    return fig
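

# Gradio handler: sample a 7x7 grid of source keypoints inside the user-selected
# box and transfer them onto the target image.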
def generate_correspondences(source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100):
    xs = np.linspace(min_x, max_x, 7)
    ys = np.linspace(min_y, max_y, 7)
    point_list = list(product(xs, ys))
    new_points = np.asarray(point_list, dtype=np.float64).T  # (2, 49): rows of x and y
    return run_chm(source_image, target_image, selected_points=new_points, number_src_points=49,
                   chm_transform=chm_transform, display_transform=chm_transform_plot)
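
# Minimal offline usage sketch (assumes the bundled sample1.jpeg/sample2.jpeg
# exist locally); handy for testing without launching the UI:
#
#   from PIL import Image
#   fig = generate_correspondences(Image.open('sample1.jpeg').convert('RGB'),
#                                  Image.open('sample2.jpeg').convert('RGB'),
#                                  min_x=15, max_x=215, min_y=15, max_y=215)
#   fig.savefig('correspondences.png')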

# UI copy shown around the demo.
title = "Correspondence Matching with Convolutional Hough Matching Networks"
description = "Transfers keypoints from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid."
article = "<p style='text-align: center'><a href='https://github.com/juhongm999/chm' target='_blank'>Original Github Repo</a></p>"

# Assemble and launch the Gradio demo. This uses the current top-level
# gr.Image/gr.Slider components (the legacy gr.inputs namespace and the
# enable_queue flag were removed in newer Gradio releases); the transforms
# above handle resizing, so no fixed input shape is needed.
iface = gr.Interface(
    fn=generate_correspondences,
    inputs=[
        gr.Image(type='pil'),
        gr.Image(type='pil'),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label='Min X'),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label='Max X'),
        gr.Slider(minimum=1, maximum=240, step=1, value=15, label='Min Y'),
        gr.Slider(minimum=1, maximum=240, step=1, value=215, label='Max Y'),
    ],
    outputs='plot',
    title=title,
    description=description,
    article=article,
    examples=[['sample1.jpeg', 'sample2.jpeg', 15, 215, 15, 215]],
)
iface.launch()