Spaces:
Runtime error
Runtime error
Johannes
commited on
Commit
·
ebf587a
1
Parent(s):
f9a9025
update
Browse files- app.py +105 -0
- bruce.png +0 -0
- requirements.txt +3 -0
app.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import kornia as K
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import numpy as np
|
7 |
+
import matplotlib
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
matplotlib.use('Agg')
|
10 |
+
from scipy.cluster.vq import kmeans,vq,whiten
|
11 |
+
|
12 |
+
def get_coordinates_from_mask(mask_in):
|
13 |
+
x_y = np.where(mask_in != [0,0,0,255])[:2]
|
14 |
+
x_y = np.column_stack((x_y[1], x_y[0]))
|
15 |
+
x_y = np.float32(x_y)
|
16 |
+
centroids,_ = kmeans(x_y,4)
|
17 |
+
centroids = np.int64(centroids)
|
18 |
+
|
19 |
+
return centroids
|
20 |
+
|
21 |
+
def get_top_bottom_coordinates(coords: list[list[int,int]]) -> (list[int,int],list[int,int]):
|
22 |
+
top_coord = min(coords, key=lambda x : x[1])
|
23 |
+
bottom_coord = max(coords, key=lambda x : x[1])
|
24 |
+
|
25 |
+
return top_coord, bottom_coord
|
26 |
+
|
27 |
+
|
28 |
+
def sort_centroids_clockwise(centroids: np.ndarray):
|
29 |
+
c_list = centroids.tolist()
|
30 |
+
c_list.sort(key = lambda y : y[0])
|
31 |
+
|
32 |
+
left_coords = c_list[:2]
|
33 |
+
right_coords = c_list[-2:]
|
34 |
+
|
35 |
+
top_left, bottom_left = get_top_bottom_coordinates(left_coords)
|
36 |
+
top_right, bottom_right = get_top_bottom_coordinates(right_coords)
|
37 |
+
|
38 |
+
return top_left, top_right, bottom_right, bottom_left
|
39 |
+
|
40 |
+
|
41 |
+
def infer(image_input, dst_height:str, dst_width:str):
|
42 |
+
image_in = image_input["image"]
|
43 |
+
mask_in = image_input["mask"]
|
44 |
+
torch_img = K.image_to_tensor(image_in)
|
45 |
+
|
46 |
+
centroids = get_coordinates_from_mask(mask_in)
|
47 |
+
ordered_src_coords = sort_centroids_clockwise(centroids)
|
48 |
+
# the source points are the region to crop corners
|
49 |
+
points_src = torch.tensor([list(ordered_src_coords)], dtype=torch.float32)
|
50 |
+
|
51 |
+
# the destination points are the image vertexes
|
52 |
+
h, w = dst_height, dst_width # destination size
|
53 |
+
points_dst = torch.tensor([[
|
54 |
+
[0., 0.], [w - 1., 0.], [w - 1., h - 1.], [0., h - 1.],
|
55 |
+
]], dtype=torch.float32)
|
56 |
+
|
57 |
+
# compute perspective transform
|
58 |
+
M: torch.tensor = K.geometry.get_perspective_transform(points_src, points_dst)
|
59 |
+
|
60 |
+
# warp the original image by the found transform
|
61 |
+
torch_img = torch.stack([torch_img],)
|
62 |
+
img_warp: torch.tensor = K.geometry.warp_perspective(torch_img.float(), M, dsize=(h, w))
|
63 |
+
|
64 |
+
|
65 |
+
# convert back to numpy
|
66 |
+
img_np = K.tensor_to_image(torch_img.byte())
|
67 |
+
img_warp_np: np.ndarray = K.tensor_to_image(img_warp.byte())
|
68 |
+
|
69 |
+
# draw points into original image
|
70 |
+
for i in range(4):
|
71 |
+
center = tuple(points_src[0, i].long().numpy())
|
72 |
+
img_np = cv2.circle(img_np.copy(), center, 5, (0, 255, 0), -1)
|
73 |
+
|
74 |
+
# create the plot
|
75 |
+
fig, axs = plt.subplots(1, 2, figsize=(16, 10))
|
76 |
+
axs = axs.ravel()
|
77 |
+
|
78 |
+
axs[0].axis('off')
|
79 |
+
axs[0].set_title('image source')
|
80 |
+
axs[0].imshow(img_np)
|
81 |
+
|
82 |
+
axs[1].axis('off')
|
83 |
+
axs[1].set_title('image destination')
|
84 |
+
axs[1].imshow(img_warp_np)
|
85 |
+
|
86 |
+
return fig
|
87 |
+
|
88 |
+
|
89 |
+
description = """Homography Warping"""
|
90 |
+
|
91 |
+
example_mask = np.empty((327,600,4))
|
92 |
+
example_mask[:] = [0,0,0,255]
|
93 |
+
example_image_dict = {"image": "bruce.png", "mask": example_mask}
|
94 |
+
|
95 |
+
Iface = gr.Interface(
|
96 |
+
fn=infer,
|
97 |
+
inputs=[gr.components.Image(tool="sketch"),
|
98 |
+
gr.components.Textbox(label="Destination Height"),
|
99 |
+
gr.components.Textbox(label="Destination Width"),
|
100 |
+
],
|
101 |
+
outputs=gr.components.Plot(),
|
102 |
+
examples=[["bruce.png", example_mask], "64", "128"],
|
103 |
+
title="Homography Warping",
|
104 |
+
description=description,
|
105 |
+
).launch()
|
bruce.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
kornia
|
2 |
+
opencv-python
|
3 |
+
matplotlib
|