Harshithtd committed on
Commit
aa11617
·
verified ·
1 Parent(s): 02da3a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -15
app.py CHANGED
@@ -7,21 +7,31 @@ from gradio_image_prompter import ImagePrompter
7
  import spaces
8
 
9
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
 
 
11
  slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
12
 
 
 
 
 
 
13
  @spaces.GPU
14
- def sam_box_inference(image, x_min, y_min, x_max, y_max):
15
- inputs = slimsam_processor(
 
 
 
16
  Image.fromarray(image),
17
  input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
18
  return_tensors="pt"
19
  ).to(device)
20
 
21
  with torch.no_grad():
22
- outputs = slimsam_model(**inputs)
23
 
24
- mask = slimsam_processor.image_processor.post_process_masks(
25
  outputs.pred_masks.cpu(),
26
  inputs["original_sizes"].cpu(),
27
  inputs["reshaped_input_sizes"].cpu()
@@ -31,27 +41,98 @@ def sam_box_inference(image, x_min, y_min, x_max, y_max):
31
  print(mask.shape)
32
  return [(mask, "mask")]
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  def infer_box(prompts):
 
35
  image = prompts["image"]
36
  if image is None:
37
- gr.Error("Please upload an image and draw a box before submitting.")
38
  points = prompts["points"][0]
39
  if points is None:
40
- gr.Error("Please draw a box before submitting.")
41
  print(points)
42
- return [(image, sam_box_inference(image, points[0], points[1], points[3], points[4]))]
43
-
44
- with gr.Blocks(title="SlimSAM Box Prompt") as demo:
45
- gr.Markdown("# SlimSAM Box Prompt")
46
- gr.Markdown("In this demo, you can upload an image and draw a box for SlimSAM to process.")
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  with gr.Row():
49
  with gr.Column():
50
  im = ImagePrompter()
51
  btn = gr.Button("Submit")
52
  with gr.Column():
53
- output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
 
54
 
55
- btn.click(infer_box, inputs=im, outputs=[output_box_slimsam])
56
 
57
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import spaces
8
 
9
# Run on GPU when available; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Use `device` (not a hard-coded "cuda") so the app also starts on CPU-only hosts;
# the original `.to("cuda")` crashed wherever torch.cuda.is_available() is False.
sam_model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
slimsam_model = SamModel.from_pretrained("nielsr/slimsam-50-uniform").to(device)
slimsam_processor = SamProcessor.from_pretrained("nielsr/slimsam-50-uniform")
14
 
15
def get_processor_and_model(slim: bool):
    """Return the (processor, model) pair: SlimSAM when *slim* is true, full SAM otherwise."""
    return (slimsam_processor, slimsam_model) if slim else (sam_processor, sam_model)
19
+
20
  @spaces.GPU
21
+ def sam_box_inference(image, x_min, y_min, x_max, y_max, *, slim=False):
22
+
23
+ processor, model = get_processor_and_model(slim)
24
+
25
+ inputs = processor(
26
  Image.fromarray(image),
27
  input_boxes=[[[[x_min, y_min, x_max, y_max]]]],
28
  return_tensors="pt"
29
  ).to(device)
30
 
31
  with torch.no_grad():
32
+ outputs = model(**inputs)
33
 
34
+ mask = processor.image_processor.post_process_masks(
35
  outputs.pred_masks.cpu(),
36
  inputs["original_sizes"].cpu(),
37
  inputs["reshaped_input_sizes"].cpu()
 
41
  print(mask.shape)
42
  return [(mask, "mask")]
43
 
44
@spaces.GPU
def sam_point_inference(image, x, y, *, slim=False):
    """Segment *image* at a single point prompt.

    Args:
        image: PIL image (RGB) to segment.
        x, y: pixel coordinates of the prompt point.
        slim: keyword-only; use SlimSAM when True, full SAM otherwise.

    Returns:
        A list with one (mask, label) tuple for gr.AnnotatedImage.
    """
    processor, model = get_processor_and_model(slim)

    inputs = processor(
        image,
        input_points=[[[x, y]]],
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Call post_process_masks via image_processor, consistent with sam_box_inference.
    mask = processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0][0].numpy()
    # AnnotatedImage expects a (1, H, W) mask; add the leading axis.
    mask = mask[np.newaxis, ...]
    print(type(mask))
    print(mask.shape)
    return [(mask, "mask")]
66
+
67
def infer_point(img):
    """Handle a point prompt from gr.ImageEditor and run both models at that point.

    The editor payload holds: "background" (original image), "layers"[0]
    (the drawn point scribble) and "composite" (the flattened image).

    Returns:
        Two (image, annotations) pairs — SlimSAM first, then SAM — matching
        the two AnnotatedImage outputs.

    Raises:
        gr.Error: if no image was uploaded or no point was drawn.
    """
    # gr.Error must be *raised* (not just constructed) to abort the callback;
    # the original called gr.Error(...) and fell through to a crash.
    if img is None or img["background"] is None:
        raise gr.Error("Please upload an image and select a point.")

    image = img["background"].convert("RGB")
    point_prompt = img["layers"][0]

    img_arr = np.array(point_prompt)
    if not np.any(img_arr):
        raise gr.Error("Please select a point on top of the image.")

    # Prompt point = centroid of the drawn (non-zero) pixels.
    nonzero_indices = np.nonzero(img_arr)
    center_x = int(np.mean(nonzero_indices[1]))
    center_y = int(np.mean(nonzero_indices[0]))
    print("Point inference returned.")
    return ((image, sam_point_inference(image, center_x, center_y, slim=True)),
            (image, sam_point_inference(image, center_x, center_y)))
88
+
89
def infer_box(prompts):
    """Handle a box prompt from ImagePrompter and run both models on the box.

    Args:
        prompts: dict with "image" (numpy array) and "points" (box coordinates).

    Returns:
        Two (image, annotations) pairs — SlimSAM first, then SAM — matching
        the two AnnotatedImage outputs.

    Raises:
        gr.Error: if no image was uploaded or no box was drawn.
    """
    # gr.Error must be *raised* (not just constructed) to abort the callback;
    # the original called gr.Error(...) and continued with None data.
    image = prompts["image"]
    if image is None:
        raise gr.Error("Please upload an image and draw a box before submitting")
    points = prompts["points"][0]
    if points is None:
        raise gr.Error("Please draw a box before submitting.")
    print(points)

    # ImagePrompter box layout: x_min = points[0], y_min = points[1],
    # x_max = points[3], y_max = points[4]
    return ((image, sam_box_inference(image, points[0], points[1], points[3], points[4], slim=True)),
            (image, sam_box_inference(image, points[0], points[1], points[3], points[4])))
102
# Gradio UI: two tabs compare SlimSAM against full SAM on box and point prompts.
with gr.Blocks(title="SlimSAM") as demo:
    gr.Markdown("# SlimSAM")
    gr.Markdown("SlimSAM is the pruned-distilled version of SAM that is smaller.")
    gr.Markdown("In this demo, you can compare SlimSAM and SAM outputs in point and box prompts.")

    with gr.Tab("Box Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                # Instructions (typo fixed: "upload and image" -> "upload an image").
                gr.Markdown("To try box prompting, simply upload an image and draw a box on it.")
        with gr.Row():
            with gr.Column():
                im = ImagePrompter()
                btn = gr.Button("Submit")
            with gr.Column():
                output_box_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_box_sam = gr.AnnotatedImage(label="SAM Output")

        btn.click(infer_box, inputs=im, outputs=[output_box_slimsam, output_box_sam])

    with gr.Tab("Point Prompt"):
        with gr.Row():
            with gr.Column(scale=1):
                # Instructions (typo fixed: "upload and image" -> "upload an image").
                gr.Markdown("To try point prompting, simply upload an image and leave a dot on it.")
        with gr.Row():
            with gr.Column():
                im = gr.ImageEditor(
                    type="pil",
                )
            with gr.Column():
                output_slimsam = gr.AnnotatedImage(label="SlimSAM Output")
                output_sam = gr.AnnotatedImage(label="SAM Output")

        # Run inference whenever the editor content changes (a dot is drawn).
        im.change(infer_point, inputs=im, outputs=[output_slimsam, output_sam])

demo.launch(debug=True)