ciCic commited on
Commit
18bac9f
·
1 Parent(s): 04afeac
Files changed (2) hide show
  1. app.py +127 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import torch
4
+ from PIL import Image
5
+ from transformers import SamModel, SamProcessor
6
+ from gradio_image_prompter import ImagePrompter
7
+
8
+ device = 'cpu'
9
+ model_id = "nielsr/slimsam-50-uniform"
10
+
11
+ slim_sam_model = SamModel.from_pretrained(model_id).to(device)
12
+ slim_sam_processor = SamProcessor.from_pretrained(model_id)
13
+
14
+
15
+ def sam_box_inference(image, x_min, y_min, x_max, y_max):
16
+ processor, model = slim_sam_processor, slim_sam_model
17
+
18
+ inputs = processor(
19
+ Image.fromarray(image),
20
+ input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
21
+ return_tensors="pt"
22
+ ).to(device)
23
+
24
+ with torch.no_grad():
25
+ outputs = model(**inputs)
26
+
27
+ mask = processor.image_processor.post_process_masks(
28
+ outputs.pred_masks.cpu(),
29
+ inputs["original_sizes"].cpu(),
30
+ inputs["reshaped_input_sizes"].cpu()
31
+ )[0][0][0].numpy()
32
+ mask = mask[np.newaxis, ...]
33
+ print(mask)
34
+ print(mask.shape)
35
+ return [(mask, "mask")]
36
+
37
+
38
+ def sam_point_inference(image, x, y):
39
+ processor, model = slim_sam_processor, slim_sam_model
40
+
41
+ inputs = processor(
42
+ image,
43
+ input_points=[[[x, y]]],
44
+ return_tensors="pt").to(device)
45
+
46
+ with torch.no_grad():
47
+ outputs = model(**inputs)
48
+
49
+ mask = processor.post_process_masks(
50
+ outputs.pred_masks.cpu(),
51
+ inputs["original_sizes"].cpu(),
52
+ inputs["reshaped_input_sizes"].cpu()
53
+ )[0][0][0].numpy()
54
+ mask = mask[np.newaxis, ...]
55
+ print(type(mask))
56
+ print(mask.shape)
57
+ return [(mask, "mask")]
58
+
59
+
60
+ def infer_point(img):
61
+ if img is None:
62
+ gr.Error("Please upload an image and select a point.")
63
+ if img["background"] is None:
64
+ gr.Error("Please upload an image and select a point.")
65
+
66
+ image = img["background"].convert("RGB")
67
+ point_prompt = img["layers"][0]
68
+ total_image = img["composite"]
69
+ img_arr = np.array(point_prompt)
70
+ if not np.any(img_arr):
71
+ gr.Error("Please select a point on top of the image.")
72
+ else:
73
+ nonzero_indices = np.nonzero(img_arr)
74
+ img_arr = np.array(point_prompt)
75
+ nonzero_indices = np.nonzero(img_arr)
76
+ center_x = int(np.mean(nonzero_indices[1]))
77
+ center_y = int(np.mean(nonzero_indices[0]))
78
+ print("Point inference returned.")
79
+ return (image, sam_point_inference(image, center_x, center_y))
80
+
81
+
82
+ def infer_box(prompts):
83
+ image = prompts["image"]
84
+ if image is None:
85
+ gr.Error("Please upload an image and draw a box before submitting")
86
+ points = prompts["points"][0]
87
+ if points is None:
88
+ gr.Error("Please draw a box before submitting.")
89
+ print(points)
90
+
91
+ return (image, sam_box_inference(image, points[0], points[1], points[3], points[4]))
92
+
93
+
94
+ if __name__ == '__main__':
95
+ with gr.Blocks(title="SlimSAM") as demo:
96
+ gr.Markdown("# SlimSAM")
97
+ gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
98
+ gr.Markdown("In this demo, you can compare SlimSAM outputs in point and box prompts.")
99
+
100
+ with gr.Tab("Box Prompt"):
101
+ with gr.Row():
102
+ with gr.Column(scale=1):
103
+ gr.Markdown("To try box prompting, simply upload and image and draw a box on it.")
104
+ with gr.Row():
105
+ with gr.Column():
106
+ im = ImagePrompter()
107
+ btn = gr.Button("Submit")
108
+ with gr.Column():
109
+ output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
110
+
111
+ btn.click(infer_box, inputs=im, outputs=[output_box_slimsam])
112
+
113
+ with gr.Tab("Point Prompt"):
114
+ with gr.Row():
115
+ with gr.Column(scale=1):
116
+ gr.Markdown("To try point prompting, simply upload and image and leave a dot on it.")
117
+ with gr.Row():
118
+ with gr.Column():
119
+ im = gr.ImageEditor(
120
+ type="pil",
121
+ )
122
+ with gr.Column():
123
+ output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
124
+
125
+ im.change(infer_point, inputs=im, outputs=[output_slimsam])
126
+
127
+ demo.launch(debug=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio-image-prompter
2
+ transformers
3
+ torch
4
+ jupyter