"""Gradio demo: warp a user-marked quadrilateral onto a rectangle with Kornia."""

import gradio as gr
import torch
import kornia as K
import cv2
import numpy as np
import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.cluster.vq import kmeans, vq, whiten


def get_coordinates_from_mask(mask_in):
    """Locate four marked points in a sketch mask.

    Every pixel that differs from the background value ``[0, 0, 0, 255]`` in
    at least one channel is treated as marked. The (x, y) positions of the
    marked pixels are clustered into four groups with k-means and the cluster
    centres are returned.

    Args:
        mask_in: Array compared against a 4-value background, so the last
            axis is presumably RGBA (assumption - confirm against the
            gradio sketch tool output).

    Returns:
        ``np.ndarray`` of dtype int64 with one ``(x, y)`` row per centre.

    Raises:
        ValueError: If fewer than four pixels are marked, or k-means yields
            fewer than four centres.
    """
    # Reduce over the channel axis so each marked pixel contributes exactly
    # one coordinate, regardless of how many channels differ.
    marked = np.any(np.asarray(mask_in) != [0, 0, 0, 255], axis=-1)
    rows, cols = np.nonzero(marked)
    if rows.size < 4:
        raise ValueError("Mark four corner points on the image before warping.")

    # k-means works on (x, y) pairs: column index first, row index second.
    x_y = np.float32(np.column_stack((cols, rows)))
    centroids, _ = kmeans(x_y, 4)
    if len(centroids) != 4:
        raise ValueError("Could not find four distinct points in the mask.")
    return np.int64(centroids)


def get_top_bottom_coordinates(
    coords: list[list[int]],
) -> tuple[list[int], list[int]]:
    """Return the points with the smallest and largest y value.

    Args:
        coords: Non-empty list of ``[x, y]`` points.

    Returns:
        ``(top, bottom)`` where ``top`` has the minimum ``y`` and ``bottom``
        the maximum ``y``.
    """
    top_coord = min(coords, key=lambda point: point[1])
    bottom_coord = max(coords, key=lambda point: point[1])
    return top_coord, bottom_coord


def sort_centroids_clockwise(centroids: np.ndarray):
    """Order four points as top-left, top-right, bottom-right, bottom-left.

    The points are sorted by x; the two with the smallest x form the left
    side and the two with the largest x form the right side. Each side is
    then split into top and bottom by y.

    Args:
        centroids: Array of ``(x, y)`` rows.

    Returns:
        Tuple ``(top_left, top_right, bottom_right, bottom_left)`` of
        ``[x, y]`` lists.
    """
    by_x = sorted(centroids.tolist(), key=lambda point: point[0])
    top_left, bottom_left = get_top_bottom_coordinates(by_x[:2])
    top_right, bottom_right = get_top_bottom_coordinates(by_x[-2:])
    return top_left, top_right, bottom_right, bottom_left


def infer(image_input, dst_height: str, dst_width: str):
    """Warp the marked region of an image to a rectangle and plot the result.

    Args:
        image_input: Mapping with an ``"image"`` entry (the picture) and a
            ``"mask"`` entry (the sketch marking the four corners).
        dst_height: Height of the warped output, as text convertible to int.
        dst_width: Width of the warped output, as text convertible to int.

    Returns:
        A matplotlib figure with the source image (corners highlighted) on
        the left and the warped image on the right.

    Raises:
        ValueError: If a destination size is not a positive integer, or the
            mask does not contain four usable points.
    """
    image_in = image_input["image"]
    mask_in = image_input["mask"]

    # The sizes come from text boxes, so convert before doing arithmetic.
    h, w = int(dst_height), int(dst_width)
    if h < 1 or w < 1:
        raise ValueError("Destination height and width must be positive integers.")

    torch_img = K.image_to_tensor(image_in)

    centroids = get_coordinates_from_mask(mask_in)
    ordered_src_coords = sort_centroids_clockwise(centroids)

    # the source points are the region to crop corners
    points_src = torch.tensor([list(ordered_src_coords)], dtype=torch.float32)

    # the destination points are the image vertexes
    points_dst = torch.tensor([[
        [0., 0.], [w - 1., 0.], [w - 1., h - 1.], [0., h - 1.],
    ]], dtype=torch.float32)

    # compute perspective transform
    M: torch.Tensor = K.geometry.get_perspective_transform(points_src, points_dst)

    # warp the original image by the found transform (a batch axis is added first)
    torch_img = torch.stack([torch_img],)
    img_warp: torch.Tensor = K.geometry.warp_perspective(
        torch_img.float(), M, dsize=(h, w)
    )

    # convert back to numpy
    img_np = K.tensor_to_image(torch_img.byte())
    img_warp_np: np.ndarray = K.tensor_to_image(img_warp.byte())

    # draw points into original image; copy once so drawing works on an
    # independent buffer, and pass plain Python ints to OpenCV
    img_np = img_np.copy()
    for point in points_src[0].long().tolist():
        img_np = cv2.circle(img_np, tuple(point), 5, (0, 255, 0), -1)

    # create the plot
    fig, axs = plt.subplots(1, 2, figsize=(16, 10))
    axs = axs.ravel()

    axs[0].axis('off')
    axs[0].set_title('image source')
    axs[0].imshow(img_np)

    axs[1].axis('off')
    axs[1].set_title('image destination')
    axs[1].imshow(img_warp_np)

    return fig


description = """Homography Warping"""

# Fully "unmarked" mask: every pixel equals the background value
# [0, 0, 0, 255] that get_coordinates_from_mask compares against.
example_mask = np.empty((327, 600, 4))
example_mask[:] = [0, 0, 0, 255]
example_image_dict = {"image": "bruce.png", "mask": example_mask}

# NOTE(review): `examples` mixes one list with two bare strings; gradio
# documents a nested list with one value per input component - confirm this
# shape against the installed gradio version.
# Iface is bound to the return value of launch(), not to the Interface itself.
Iface = gr.Interface(
    fn=infer,
    inputs=[
        gr.components.Image(tool="sketch"),
        gr.components.Textbox(label="Destination Height"),
        gr.components.Textbox(label="Destination Width"),
    ],
    outputs=gr.components.Plot(),
    examples=[["bruce.png", example_mask], "64", "128"],
    title="Homography Warping",
    description=description,
).launch()