# Captured from a "Spaces" app page, which reported the status: Runtime error
import gradio as gr | |
import torch | |
import kornia as K | |
import cv2 | |
import numpy as np | |
import matplotlib | |
import matplotlib.pyplot as plt | |
matplotlib.use('Agg') | |
from scipy.cluster.vq import kmeans,vq,whiten | |
def get_coordinates_from_mask(mask_in, k=4, background=(0, 0, 0, 255)):
    """Cluster the painted pixels of a sketch mask into ``k`` corner points.

    Every pixel whose colour differs from ``background`` is treated as part of
    a user stroke. The (x, y) positions of those pixels are clustered with
    k-means, and the cluster centres are returned.

    Args:
        mask_in: Mask image as an array of shape (H, W, C) or (H, W).
            ``background`` is truncated to C channels, so both RGBA masks
            (C=4) and RGB masks (C=3) with a black background are accepted.
        k: Number of clusters (corners) to find. Defaults to 4.
        background: Colour of unpainted mask pixels.

    Returns:
        ``np.int64`` array of shape (k, 2) holding (x, y) cluster centres,
        truncated toward zero.

    Raises:
        ValueError: If the mask has fewer than ``k`` painted pixels, or if
            k-means returns fewer than ``k`` centres. scipy removes clusters
            that end up with no members, so this can happen.
    """
    mask = np.asarray(mask_in)
    if mask.ndim == 2:
        mask = mask[..., np.newaxis]
    bg = np.asarray(background)[: mask.shape[-1]]

    # A pixel is painted if ANY channel differs from the background. Testing
    # per channel (np.where on the 3-D comparison) would list a pixel once
    # per differing channel and weight the clustering by stroke colour.
    painted = np.any(mask != bg, axis=-1)
    rows, cols = np.nonzero(painted)
    if rows.size < k:
        raise ValueError(
            f"need at least {k} painted mask pixels to find {k} corners, "
            f"found {rows.size}"
        )

    # kmeans expects float observations; columns are (x, y) = (col, row).
    points = np.column_stack((cols, rows)).astype(np.float32)
    centroids, _ = kmeans(points, k)
    if len(centroids) < k:
        raise ValueError(
            f"could only separate {len(centroids)} of {k} corner clusters; "
            "mark each corner with a distinct, well-separated dot"
        )
    return np.int64(centroids)
def get_top_bottom_coordinates(coords):
    """Return ``(top, bottom)``: the points with the smallest and largest y.

    ``coords`` is a sequence of ``[x, y]`` points. Ties keep the earliest
    point in ``coords`` for both the top and the bottom pick.
    """
    def y_of(point):
        return point[1]

    return min(coords, key=y_of), max(coords, key=y_of)
def sort_centroids_clockwise(centroids: np.ndarray):
    """Order four corner points as top-left, top-right, bottom-right, bottom-left.

    The points are split into a left pair and a right pair by x coordinate.
    Each pair is then split into top and bottom by y coordinate. This uses
    image coordinates, where y grows downward, so "top" is the smaller y.

    Args:
        centroids: Array-like of shape (4, 2) holding (x, y) points.

    Returns:
        Tuple ``(top_left, top_right, bottom_right, bottom_left)``. Each item
        is an ``[x, y]`` Python list.

    Raises:
        ValueError: If ``centroids`` does not hold exactly four points. With
            fewer, the left and right pairs would overlap and duplicate a
            corner, which yields a degenerate perspective transform.
    """
    c_list = np.asarray(centroids).tolist()
    if len(c_list) != 4:
        raise ValueError(f"expected 4 corner points, got {len(c_list)}")

    c_list.sort(key=lambda point: point[0])
    top_left, bottom_left = get_top_bottom_coordinates(c_list[:2])
    top_right, bottom_right = get_top_bottom_coordinates(c_list[2:])
    return top_left, top_right, bottom_right, bottom_left
def _parse_dimension(value, label):
    """Parse a destination size typed into a textbox as a positive int.

    Raises:
        ValueError: If ``value`` is not a whole number greater than zero.
    """
    try:
        size = int(str(value).strip())
    except ValueError:
        raise ValueError(f"{label} must be a whole number, got {value!r}") from None
    if size <= 0:
        raise ValueError(f"{label} must be greater than zero, got {size}")
    return size


def _render_comparison(img_src, img_dst):
    """Return a two-panel matplotlib Figure: source on the left, warped result on the right."""
    # Build the Figure directly instead of via pyplot. pyplot keeps every
    # figure it creates alive until plt.close(), so a server that renders
    # one figure per request would accumulate them.
    from matplotlib.figure import Figure

    fig = Figure(figsize=(16, 10))
    panels = (("image source", img_src), ("image destination", img_dst))
    for ax, (title, img) in zip(fig.subplots(1, 2).ravel(), panels):
        ax.axis('off')
        ax.set_title(title)
        ax.imshow(img)
    return fig


def infer(image_input, dst_height: str, dst_width: str):
    """Rectify the quadrilateral marked on an image into a dst_height x dst_width image.

    The user marks the four corners of the region on the sketch mask. Their
    positions are recovered with ``get_coordinates_from_mask`` and ordered
    with ``sort_centroids_clockwise``. They are then mapped onto the corners
    of the destination image with a perspective (homography) warp.

    Args:
        image_input: Dict with ``"image"`` and ``"mask"`` entries, as produced
            by the sketch-enabled image input. These are presumably
            H x W x C uint8 numpy arrays; confirm against the gradio version
            in use.
        dst_height: Output height in pixels, as text (from a Textbox).
        dst_width: Output width in pixels, as text (from a Textbox).

    Returns:
        A matplotlib Figure showing the source image, with the detected
        corners drawn in green, next to the warped result.

    Raises:
        ValueError: If no image was supplied, a size is not a positive whole
            number, or four corners cannot be recovered from the mask.
    """
    if image_input is None:
        raise ValueError("Please upload an image and mark its four corners.")
    image_in = image_input["image"]
    mask_in = image_input["mask"]

    # Textbox values arrive as strings. Arithmetic such as ``w - 1.`` needs
    # numbers, and kornia's dsize needs ints.
    h = _parse_dimension(dst_height, "Destination Height")
    w = _parse_dimension(dst_width, "Destination Width")

    # Source points: the marked corners, batched to shape (1, 4, 2) in
    # top-left, top-right, bottom-right, bottom-left order.
    ordered_src_coords = sort_centroids_clockwise(get_coordinates_from_mask(mask_in))
    points_src = torch.tensor([list(ordered_src_coords)], dtype=torch.float32)
    # Destination points: the output image corners, in the same order.
    points_dst = torch.tensor([[
        [0., 0.], [w - 1., 0.], [w - 1., h - 1.], [0., h - 1.],
    ]], dtype=torch.float32)

    M: torch.Tensor = K.geometry.get_perspective_transform(points_src, points_dst)

    # Add a batch dimension, then warp in float. dsize is (height, width).
    torch_img = K.image_to_tensor(image_in).unsqueeze(0)
    img_warp: torch.Tensor = K.geometry.warp_perspective(torch_img.float(), M, dsize=(h, w))

    img_np = K.tensor_to_image(torch_img.byte())
    img_warp_np: np.ndarray = K.tensor_to_image(img_warp.byte())

    # Copy once so cv2 draws into a standalone C-contiguous buffer. Then mark
    # each source corner with a filled green circle. tolist() gives plain
    # Python ints for the centre point.
    img_np = img_np.copy()
    for x, y in points_src[0].long().tolist():
        img_np = cv2.circle(img_np, (x, y), 5, (0, 255, 0), -1)

    return _render_comparison(img_np, img_warp_np)
description = """Homography Warping"""

# Image(tool="sketch") is the gradio 3.x API: the component returns a dict
# with "image" and "mask" entries, which `infer` reads.
Iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.components.Image(tool="sketch"),
        gr.components.Textbox(label="Destination Height"),
        gr.components.Textbox(label="Destination Width"),
    ],
    outputs=gr.components.Plot(),
    # `examples` is a list of rows, with one value per input. The mask is
    # drawn by the user, so an example supplies only the image path and the
    # two destination sizes.
    examples=[["bruce.png", "64", "128"]],
    title="Homography Warping",
    description=description,
)

if __name__ == "__main__":
    Iface.launch()