yunyangx commited on
Commit
47887bb
1 Parent(s): 070c43b

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -163
app.py CHANGED
@@ -76,97 +76,88 @@ def segment_with_boxs(
76
  use_retina=True,
77
  mask_random_color=True,
78
  ):
79
- try:
80
- global global_points
81
- global global_point_label
82
- if len(global_points) < 2:
83
- return seg_image
84
- print("Original Image : ", image.size)
85
-
86
- input_size = int(input_size)
87
- w, h = image.size
88
- scale = input_size / max(w, h)
89
- new_w = int(w * scale)
90
- new_h = int(h * scale)
91
- image = image.resize((new_w, new_h))
92
-
93
- print("Scaled Image : ", image.size)
94
- print("Scale : ", scale)
95
-
96
- scaled_points = np.array(
97
- [[int(x * scale) for x in point] for point in global_points]
98
- )
99
- scaled_points = scaled_points[:2]
100
- scaled_point_label = np.array(global_point_label)[:2]
101
-
102
- print(scaled_points, scaled_points is not None)
103
- print(scaled_point_label, scaled_point_label is not None)
104
 
105
- if scaled_points.size == 0 and scaled_point_label.size == 0:
106
- print("No points selected")
107
- return image
108
 
109
- nd_image = np.array(image)
110
- img_tensor = ToTensor()(nd_image)
 
111
 
112
- print(img_tensor.shape)
113
- pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
114
- pts_sampled = pts_sampled[:, :, :2, :]
115
- pts_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
116
 
117
- predicted_logits, predicted_iou = model(
118
- img_tensor[None, ...].to(device),
119
- pts_sampled.to(device),
120
- pts_labels.to(device),
121
- )
122
- predicted_logits = predicted_logits.cpu()
123
- all_masks = torch.ge(
124
- torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5
125
- ).numpy()
126
- predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
127
-
128
- max_predicted_iou = -1
129
- selected_mask_using_predicted_iou = None
130
- selected_predicted_iou = None
131
-
132
- for m in range(all_masks.shape[0]):
133
- curr_predicted_iou = predicted_iou[m]
134
- if (
135
- curr_predicted_iou > max_predicted_iou
136
- or selected_mask_using_predicted_iou is None
137
- ):
138
- max_predicted_iou = curr_predicted_iou
139
- selected_mask_using_predicted_iou = all_masks[m : m + 1]
140
- selected_predicted_iou = predicted_iou[m : m + 1]
141
-
142
- results = format_results(
143
- selected_mask_using_predicted_iou,
144
- selected_predicted_iou,
145
- predicted_logits,
146
- 0,
147
- )
148
 
149
- annotations = results[0]["segmentation"]
150
- annotations = np.array([annotations])
151
- print(scaled_points.shape)
152
- fig = fast_process(
153
- annotations=annotations,
154
- image=image,
155
- device=device,
156
- scale=(1024 // input_size),
157
- better_quality=better_quality,
158
- mask_random_color=mask_random_color,
159
- use_retina=use_retina,
160
- bbox=scaled_points.reshape([4]),
161
- withContours=withContours,
162
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
- global_points = []
165
- global_point_label = []
166
- # return fig, None
167
- return fig
168
- except:
169
- return image
170
 
171
 
172
  def segment_with_points(
@@ -177,82 +168,77 @@ def segment_with_points(
177
  use_retina=True,
178
  mask_random_color=True,
179
  ):
180
- try:
181
- global global_points
182
- global global_point_label
183
-
184
- print("Original Image : ", image.size)
185
-
186
- input_size = int(input_size)
187
- w, h = image.size
188
- scale = input_size / max(w, h)
189
- new_w = int(w * scale)
190
- new_h = int(h * scale)
191
- image = image.resize((new_w, new_h))
192
-
193
- print("Scaled Image : ", image.size)
194
- print("Scale : ", scale)
195
-
196
- if global_points is None:
197
- return image
198
- if len(global_points) < 1:
199
- return image
200
- scaled_points = np.array(
201
- [[int(x * scale) for x in point] for point in global_points]
202
- )
203
- scaled_point_label = np.array(global_point_label)
204
 
205
- print(scaled_points, scaled_points is not None)
206
- print(scaled_point_label, scaled_point_label is not None)
207
 
208
- if scaled_points.size == 0 and scaled_point_label.size == 0:
209
- print("No points selected")
210
- return image
 
 
 
211
 
212
- nd_image = np.array(image)
213
- img_tensor = ToTensor()(nd_image)
214
 
215
- print(img_tensor.shape)
216
- pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
217
- pts_labels = torch.reshape(torch.tensor(global_point_label), [1, 1, -1])
 
 
 
 
 
218
 
219
- predicted_logits, predicted_iou = model(
220
- img_tensor[None, ...].to(device),
221
- pts_sampled.to(device),
222
- pts_labels.to(device),
223
- )
224
- predicted_logits = predicted_logits.cpu()
225
- all_masks = torch.ge(
226
- torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5
227
- ).numpy()
228
- predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
229
 
230
- results = format_results(all_masks, predicted_iou, predicted_logits, 0)
 
 
231
 
232
- annotations, _ = point_prompt(
233
- results, scaled_points, scaled_point_label, new_h, new_w
234
- )
235
- annotations = np.array([annotations])
236
-
237
- fig = fast_process(
238
- annotations=annotations,
239
- image=image,
240
- device=device,
241
- scale=(1024 // input_size),
242
- better_quality=better_quality,
243
- mask_random_color=mask_random_color,
244
- points=scaled_points,
245
- bbox=None,
246
- use_retina=use_retina,
247
- withContours=withContours,
248
- )
249
 
250
- global_points = []
251
- global_point_label = []
252
- # return fig, None
253
- return fig
254
- except:
255
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
256
 
257
 
258
  def get_points_with_draw(image, cond_image, evt: gr.SelectData):
@@ -276,16 +262,12 @@ def get_points_with_draw(image, cond_image, evt: gr.SelectData):
276
  draw = ImageDraw.Draw(image)
277
 
278
  draw.ellipse(
279
- [
280
- (x - point_radius, y - point_radius),
281
- (x + point_radius, y + point_radius),
282
- ],
283
  fill=point_color,
284
  )
285
 
286
  return image
287
 
288
-
289
  def get_points_with_draw_(image, cond_image, evt: gr.SelectData):
290
  global global_points
291
  global global_point_label
@@ -309,10 +291,7 @@ def get_points_with_draw_(image, cond_image, evt: gr.SelectData):
309
  draw = ImageDraw.Draw(image)
310
 
311
  draw.ellipse(
312
- [
313
- (x - point_radius, y - point_radius),
314
- (x + point_radius, y + point_radius),
315
- ],
316
  fill=point_color,
317
  )
318
 
@@ -411,6 +390,7 @@ with gr.Blocks(css=css, title="Efficient SAM") as demo:
411
  gr.Examples(
412
  examples=examples,
413
  inputs=[cond_img_b],
 
414
  examples_per_page=4,
415
  )
416
 
@@ -422,7 +402,9 @@ with gr.Blocks(css=css, title="Efficient SAM") as demo:
422
 
423
  cond_img_b.select(get_points_with_draw_, [segm_img_b, cond_img_b], segm_img_b)
424
 
425
- segment_btn_p.click(segment_with_points, inputs=[cond_img_p], outputs=segm_img_p)
 
 
426
 
427
  segment_btn_b.click(
428
  segment_with_boxs, inputs=[cond_img_b, segm_img_b], outputs=segm_img_b
 
76
  use_retina=True,
77
  mask_random_color=True,
78
  ):
79
+ global global_points
80
+ global global_point_label
81
+ if len(global_points) < 2:
82
+ return seg_image
83
+ print("Original Image : ", image.size)
84
+
85
+ input_size = int(input_size)
86
+ w, h = image.size
87
+ scale = input_size / max(w, h)
88
+ new_w = int(w * scale)
89
+ new_h = int(h * scale)
90
+ image = image.resize((new_w, new_h))
91
+
92
+ print("Scaled Image : ", image.size)
93
+ print("Scale : ", scale)
94
+
95
+ scaled_points = np.array(
96
+ [[int(x * scale) for x in point] for point in global_points]
97
+ )
98
+ scaled_points = scaled_points[:2]
99
+ scaled_point_label = np.array(global_point_label)[:2]
 
 
 
 
100
 
101
+ print(scaled_points, scaled_points is not None)
102
+ print(scaled_point_label, scaled_point_label is not None)
 
103
 
104
+ if scaled_points.size == 0 and scaled_point_label.size == 0:
105
+ print("No points selected")
106
+ return image
107
 
108
+ nd_image = np.array(image)
109
+ img_tensor = ToTensor()(nd_image)
 
 
110
 
111
+ print(img_tensor.shape)
112
+ pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
113
+ pts_sampled = pts_sampled[:, :, :2, :]
114
+ pts_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ predicted_logits, predicted_iou = model(
117
+ img_tensor[None, ...].to(device),
118
+ pts_sampled.to(device),
119
+ pts_labels.to(device),
120
+ )
121
+ predicted_logits = predicted_logits.cpu()
122
+ all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
123
+ predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
124
+
125
+
126
+ max_predicted_iou = -1
127
+ selected_mask_using_predicted_iou = None
128
+ selected_predicted_iou = None
129
+
130
+ for m in range(all_masks.shape[0]):
131
+ curr_predicted_iou = predicted_iou[m]
132
+ if (
133
+ curr_predicted_iou > max_predicted_iou
134
+ or selected_mask_using_predicted_iou is None
135
+ ):
136
+ max_predicted_iou = curr_predicted_iou
137
+ selected_mask_using_predicted_iou = all_masks[m:m+1]
138
+ selected_predicted_iou = predicted_iou[m:m+1]
139
+
140
+ results = format_results(selected_mask_using_predicted_iou, selected_predicted_iou, predicted_logits, 0)
141
+
142
+ annotations = results[0]["segmentation"]
143
+ annotations = np.array([annotations])
144
+ print(scaled_points.shape)
145
+ fig = fast_process(
146
+ annotations=annotations,
147
+ image=image,
148
+ device=device,
149
+ scale=(1024 // input_size),
150
+ better_quality=better_quality,
151
+ mask_random_color=mask_random_color,
152
+ use_retina=use_retina,
153
+ bbox = scaled_points.reshape([4]),
154
+ withContours=withContours,
155
+ )
156
 
157
+ global_points = []
158
+ global_point_label = []
159
+ # return fig, None
160
+ return fig
 
 
161
 
162
 
163
  def segment_with_points(
 
168
  use_retina=True,
169
  mask_random_color=True,
170
  ):
171
+ global global_points
172
+ global global_point_label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
 
174
+ print("Original Image : ", image.size)
 
175
 
176
+ input_size = int(input_size)
177
+ w, h = image.size
178
+ scale = input_size / max(w, h)
179
+ new_w = int(w * scale)
180
+ new_h = int(h * scale)
181
+ image = image.resize((new_w, new_h))
182
 
183
+ print("Scaled Image : ", image.size)
184
+ print("Scale : ", scale)
185
 
186
+ if global_points is None:
187
+ return image
188
+ if len(global_points) < 1:
189
+ return image
190
+ scaled_points = np.array(
191
+ [[int(x * scale) for x in point] for point in global_points]
192
+ )
193
+ scaled_point_label = np.array(global_point_label)
194
 
195
+ print(scaled_points, scaled_points is not None)
196
+ print(scaled_point_label, scaled_point_label is not None)
 
 
 
 
 
 
 
 
197
 
198
+ if scaled_points.size == 0 and scaled_point_label.size == 0:
199
+ print("No points selected")
200
+ return image
201
 
202
+ nd_image = np.array(image)
203
+ img_tensor = ToTensor()(nd_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
+ print(img_tensor.shape)
206
+ pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
207
+ pts_labels = torch.reshape(torch.tensor(global_point_label), [1, 1, -1])
208
+
209
+ predicted_logits, predicted_iou = model(
210
+ img_tensor[None, ...].to(device),
211
+ pts_sampled.to(device),
212
+ pts_labels.to(device),
213
+ )
214
+ predicted_logits = predicted_logits.cpu()
215
+ all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
216
+ predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
217
+
218
+ results = format_results(all_masks, predicted_iou, predicted_logits, 0)
219
+
220
+ annotations, _ = point_prompt(
221
+ results, scaled_points, scaled_point_label, new_h, new_w
222
+ )
223
+ annotations = np.array([annotations])
224
+
225
+ fig = fast_process(
226
+ annotations=annotations,
227
+ image=image,
228
+ device=device,
229
+ scale=(1024 // input_size),
230
+ better_quality=better_quality,
231
+ mask_random_color=mask_random_color,
232
+ points = scaled_points,
233
+ bbox=None,
234
+ use_retina=use_retina,
235
+ withContours=withContours,
236
+ )
237
+
238
+ global_points = []
239
+ global_point_label = []
240
+ # return fig, None
241
+ return fig
242
 
243
 
244
  def get_points_with_draw(image, cond_image, evt: gr.SelectData):
 
262
  draw = ImageDraw.Draw(image)
263
 
264
  draw.ellipse(
265
+ [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
 
 
 
266
  fill=point_color,
267
  )
268
 
269
  return image
270
 
 
271
  def get_points_with_draw_(image, cond_image, evt: gr.SelectData):
272
  global global_points
273
  global global_point_label
 
291
  draw = ImageDraw.Draw(image)
292
 
293
  draw.ellipse(
294
+ [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
 
 
 
295
  fill=point_color,
296
  )
297
 
 
390
  gr.Examples(
391
  examples=examples,
392
  inputs=[cond_img_b],
393
+
394
  examples_per_page=4,
395
  )
396
 
 
402
 
403
  cond_img_b.select(get_points_with_draw_, [segm_img_b, cond_img_b], segm_img_b)
404
 
405
+ segment_btn_p.click(
406
+ segment_with_points, inputs=[cond_img_p], outputs=segm_img_p
407
+ )
408
 
409
  segment_btn_b.click(
410
  segment_with_boxs, inputs=[cond_img_b, segm_img_b], outputs=segm_img_b