"""Gradio demo for CHMNet keypoint correspondences.

Downloads the ``pas_psi.pt`` checkpoint, loads it into a CHMNet model (weights
mapped to the CPU) and serves a Gradio interface that lays a 7x7 grid of
keypoints on a source image, transfers them to a target image with the model,
and plots both point sets side by side.
"""
from itertools import product

from torch.utils.data import DataLoader
import torch
from model.base.geometry import Geometry
from common.evaluation import Evaluator
from common.logger import AverageMeter
from common.logger import Logger
from data import download
from model import chmnet
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from matplotlib.patches import ConnectionPatch
from PIL import Image
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import random
import gradio as gr

# Downloading the Model
torchvision.datasets.utils.download_file_from_google_drive(
    '1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6', '.', 'pas_psi.pt')

# Model Initialization
args = {
    'alpha': [0.05, 0.1],
    'benchmark': 'pfpascal',
    'bsz': 90,
    'datapath': '../Datasets_CHM',
    'img_size': 240,
    'ktype': 'psi',
    'load': 'pas_psi.pt',
    'thres': 'img',
}

model = chmnet.CHMNet(args['ktype'])
model.load_state_dict(torch.load(args['load'], map_location=torch.device('cpu')))
Evaluator.initialize(args['alpha'])
Geometry.initialize(img_size=args['img_size'])
model.eval()

# Transforms
# Model input: resize + square center crop to img_size, then tensor + ImageNet
# mean/std normalization.
chm_transform = transforms.Compose([
    transforms.Resize(args['img_size']),
    transforms.CenterCrop((args['img_size'], args['img_size'])),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Display image: same geometry as the model input, but left as an image so it
# can be drawn and its ``.size`` read.
chm_transform_plot = transforms.Compose([
    transforms.Resize(args['img_size']),
    transforms.CenterCrop((args['img_size'], args['img_size'])),
])


# A Helper Function
def to_np(x):
    """Return tensor ``x`` (via ``.data``) as a NumPy array on the CPU."""
    return x.data.to('cpu').numpy()


# Colors for Plotting: one Spectral colour per keypoint of the 7x7 grid.
# ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the pyplot
# accessor remains available across versions.
cmap = plt.get_cmap('Spectral')
colors = [cmap(k / 49.0) for k in range(49)]


# CHM MODEL
def run_chm(source_image, target_image, selected_points, number_src_points,
            chm_transform, display_transform):
    """Transfer keypoints from ``source_image`` to ``target_image`` and plot them.

    Args:
        source_image: Source image. Presumably a PIL image: the Gradio inputs
            use ``type='pil'``, and ``display_transform(...).size`` is read
            as ``(w, h)``.
        target_image: Target image, of the same kind as ``source_image``.
        selected_points: Array of shape ``(2, N)`` holding the source keypoint
            x (row 0) and y (row 1) coordinates in ``img_size`` pixel space.
        number_src_points: Number of keypoints to transfer and draw.
        chm_transform: Transform that turns an image into the model input tensor.
        display_transform: Transform that produces the image to display.

    Returns:
        A ``matplotlib.figure.Figure`` with a 'Source' and a 'Target' panel.
    """
    img_size = args['img_size']

    # Convert to batched (batch size 1) tensors for the model.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # RUN CHM
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    # VISUALIZATION
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Transform each display image once; reuse it for its size and for drawing.
    src_display = display_transform(source_image)
    tgt_display = display_transform(target_image)

    # Source points are in img_size pixel space: rescale them to display size.
    # astype(int) truncates toward zero, like the int() it replaces.
    w, h = src_display.size
    src_points_converted = np.stack(
        [src_points[0] * w / img_size, src_points[1] * h / img_size], axis=1
    ).astype(int)[:number_src_points]

    # The (x + 1) / 2 mapping assumes transfer_kps returns coordinates in
    # [-1, 1] -- TODO confirm against Geometry.transfer_kps.
    w, h = tgt_display.size
    tgt_points_converted = np.stack(
        [(tgt_points[0] + 1) / 2.0 * w, (tgt_points[1] + 1) / 2.0 * h], axis=1
    ).astype(int)[:number_src_points]

    # PLOT -- build the Figure directly instead of through pyplot, so that
    # pyplot's global figure registry does not keep one figure alive per
    # Gradio request.
    fig = Figure(figsize=(12, 8))
    ax_src, ax_tgt = fig.subplots(nrows=1, ncols=2)
    point_colors = colors[:number_src_points]

    panels = (
        (ax_src, src_display, src_points_converted, 'Source', dict(color='red', size=10)),
        (ax_tgt, tgt_display, tgt_points_converted, 'Target', dict(color='orange', size=8)),
    )
    for ax, image, points, title, label_font in panels:
        ax.imshow(image)
        ax.scatter(points[:, 0], points[:, 1], c=point_colors)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])
        # Label only the points actually drawn (not a fixed 49), so a smaller
        # number_src_points cannot index past the end of ``points``.
        for idx, (x, y) in enumerate(points):
            ax.text(x=x, y=y, s=str(idx), fontdict=label_font)

    fig.tight_layout()
    # Backslashes are escaped explicitly: '\i' and '\_' are invalid escape
    # sequences (SyntaxWarning on Python 3.12+). The runtime string is unchanged.
    fig.suptitle('CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights ', fontsize=16)
    return fig


# Wrapper
def generate_correspondences(sousrce_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100):
    """Lay out a 7x7 keypoint grid on the source image and run CHM on it.

    Args:
        sousrce_image: Source image (parameter name kept as-is for caller
            compatibility).
        target_image: Target image.
        min_x, max_x, min_y, max_y: Bounds of the grid, in ``img_size`` pixel
            space; each axis is split into 7 evenly spaced values.

    Returns:
        The figure produced by ``run_chm``.
    """
    xs = np.linspace(min_x, max_x, 7)
    ys = np.linspace(min_y, max_y, 7)
    # product(xs, ys) yields (x, y) pairs; transpose to the (2, 49) layout
    # that run_chm indexes as rows of x and y.
    new_points = np.asarray(list(product(xs, ys)), dtype=np.float64).T
    return run_chm(sousrce_image, target_image,
                   selected_points=new_points,
                   number_src_points=49,
                   chm_transform=chm_transform,
                   display_transform=chm_transform_plot)


# GRADIO APP
iface = gr.Interface(
    fn=generate_correspondences,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type='pil'),
        gr.inputs.Image(shape=(240, 240), type='pil'),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='MinX'),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='MaxX'),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='MinY'),
        gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='MaxY'),
    ],
    outputs="plot",
)
iface.launch()