mischeiwiller committed
Commit 8b82a8c
1 Parent(s): b8884cd

Update app.py

Files changed (1):
  app.py  +9 -32
app.py CHANGED
@@ -6,34 +6,11 @@ import numpy as np
 import matplotlib.pyplot as plt
 from scipy.cluster.vq import kmeans
 
-def get_coordinates_from_mask(mask_in):
-    x_y = np.where(mask_in != [0,0,0,255])[:2]
-    x_y = np.column_stack((x_y[1], x_y[0]))
-    x_y = np.float32(x_y)
-    centroids,_ = kmeans(x_y,4)
-    centroids = np.int64(centroids)
-    return centroids
-
-def get_top_bottom_coordinates(coords):
-    top_coord = min(coords, key=lambda x : x[1])
-    bottom_coord = max(coords, key=lambda x : x[1])
-    return top_coord, bottom_coord
-
-def sort_centroids_clockwise(centroids: np.ndarray):
-    c_list = centroids.tolist()
-    c_list.sort(key = lambda y : y[0])
-
-    left_coords = c_list[:2]
-    right_coords = c_list[-2:]
-
-    top_left, bottom_left = get_top_bottom_coordinates(left_coords)
-    top_right, bottom_right = get_top_bottom_coordinates(right_coords)
-
-    return top_left, top_right, bottom_right, bottom_left
-
-def infer(image_input, dst_height: str, dst_width: str):
-    image_in = image_input["image"]
-    mask_in = image_input["mask"]
+# ... (previous functions remain the same)
+
+def infer(image_input, mask_input, dst_height: str, dst_width: str):
+    image_in = image_input
+    mask_in = mask_input
     torch_img = K.utils.image_to_tensor(image_in).float() / 255.0
 
     centroids = get_coordinates_from_mask(mask_in)
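For reference, a quick sanity sketch of the ordering helpers removed above (the "# ... (previous functions remain the same)" comment in the new version implies they stay defined elsewhere); the four (x, y) centroids below are fabricated values, not data from the app:

import numpy as np

# Assumes sort_centroids_clockwise from the removed block above is in scope.
# Fabricated (x, y) centroids, one per sketched corner marker.
centroids = np.array([[480, 70], [120, 60], [130, 260], [470, 250]])

# The helper splits the points into a left and a right pair by x, orders each
# pair by y, and returns them as TL, TR, BR, BL.
tl, tr, br, bl = sort_centroids_clockwise(centroids)
print(tl, tr, br, bl)  # [120, 60] [480, 70] [470, 250] [130, 260]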
@@ -78,14 +55,14 @@ description = """In this space you can warp an image using perspective transform
 
 example_mask = np.zeros((327, 600, 4), dtype=np.uint8)
 example_mask[:, :, 3] = 255
-example_image_dict = {"image": "bruce.png", "mask": example_mask}
 
 with gr.Blocks() as demo:
     gr.Markdown("# Homography Warping")
     gr.Markdown(description)
 
     with gr.Row():
-        image_input = gr.Image(tool="sketch", type="numpy", label="Input Image")
+        image_input = gr.Image(label="Input Image", type="numpy")
+        mask_input = gr.Image(label="Mask", source="canvas", shape=(600, 327), tool="sketch", interactive=True)
         output_plot = gr.Plot(label="Output")
 
     with gr.Row():
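The rest of infer is unchanged and not shown in this commit. As a hedged sketch of the perspective warp the description refers to, this is how four clockwise corners typically drive a homography warp with the standard Kornia API (get_perspective_transform / warp_perspective); it is not the app's actual code:

import torch
import kornia as K

def warp_from_corners(torch_img, top_left, top_right, bottom_right, bottom_left,
                      dst_height: int, dst_width: int):
    # Source corners as (x, y), clockwise from top-left, shape (1, 4, 2).
    points_src = torch.tensor([[top_left, top_right, bottom_right, bottom_left]],
                              dtype=torch.float32)
    # Matching corners of the output rectangle, same order.
    points_dst = torch.tensor([[[0.0, 0.0],
                                [dst_width - 1.0, 0.0],
                                [dst_width - 1.0, dst_height - 1.0],
                                [0.0, dst_height - 1.0]]])
    # 3x3 homography mapping src -> dst, applied to the (C, H, W) image tensor.
    M = K.geometry.transform.get_perspective_transform(points_src, points_dst)
    return K.geometry.transform.warp_perspective(torch_img[None], M,
                                                 dsize=(dst_height, dst_width))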
@@ -95,13 +72,13 @@ with gr.Blocks() as demo:
     submit_button = gr.Button("Submit")
     submit_button.click(
         fn=infer,
-        inputs=[image_input, dst_height, dst_width],
+        inputs=[image_input, mask_input, dst_height, dst_width],
         outputs=output_plot
     )
 
     gr.Examples(
-        examples=[[example_image_dict, "64", "128"]],
-        inputs=[image_input, dst_height, dst_width],
+        examples=[["bruce.png", example_mask, "64", "128"]],
+        inputs=[image_input, mask_input, dst_height, dst_width],
         outputs=output_plot,
         fn=infer,
         cache_examples=True
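Finally, a hedged sketch of calling the refactored infer signature directly, mirroring what the Submit button wires up; "bruce.png" is the example image referenced above, while the marked corner positions in the mask are fabricated so that get_coordinates_from_mask has non-background pixels to cluster (the plain example_mask above contains none):

import numpy as np
from PIL import Image

# Assumes infer and its helpers from app.py are in scope.
image_in = np.array(Image.open("bruce.png").convert("RGB"))

# Opaque black RGBA canvas with four white squares standing in for sketched corners.
mask_in = np.zeros((327, 600, 4), dtype=np.uint8)
mask_in[:, :, 3] = 255
for y, x in [(60, 120), (70, 480), (260, 130), (250, 470)]:  # fabricated positions
    mask_in[y - 4:y + 4, x - 4:x + 4, :3] = 255

fig = infer(image_in, mask_in, "64", "128")  # same argument order as the click handler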