AAAAAAAyq
commited on
Commit
•
4b45202
1
Parent(s):
e03ed2b
Better points mode & Fix the Contours button bug
Browse files- app_gradio.py +4 -4
- utils/tools.py +4 -3
app_gradio.py
CHANGED
@@ -221,7 +221,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
221 |
input_size_slider.render()
|
222 |
|
223 |
with gr.Row():
|
224 |
-
|
225 |
|
226 |
with gr.Column():
|
227 |
segment_btn_e = gr.Button("Segment Everything", variant='primary')
|
@@ -298,7 +298,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
298 |
info='Our model was trained on a size of 1024')
|
299 |
with gr.Row():
|
300 |
with gr.Column():
|
301 |
-
|
302 |
text_box = gr.Textbox(label="text prompt", value="a black dog")
|
303 |
|
304 |
with gr.Column():
|
@@ -334,7 +334,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
334 |
iou_threshold,
|
335 |
conf_threshold,
|
336 |
mor_check,
|
337 |
-
|
338 |
retina_check,
|
339 |
],
|
340 |
outputs=segm_img_e)
|
@@ -350,7 +350,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
|
|
350 |
iou_threshold,
|
351 |
conf_threshold,
|
352 |
mor_check,
|
353 |
-
|
354 |
retina_check,
|
355 |
text_box,
|
356 |
],
|
|
|
221 |
input_size_slider.render()
|
222 |
|
223 |
with gr.Row():
|
224 |
+
contour_check_e = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
|
225 |
|
226 |
with gr.Column():
|
227 |
segment_btn_e = gr.Button("Segment Everything", variant='primary')
|
|
|
298 |
info='Our model was trained on a size of 1024')
|
299 |
with gr.Row():
|
300 |
with gr.Column():
|
301 |
+
contour_check_t = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
|
302 |
text_box = gr.Textbox(label="text prompt", value="a black dog")
|
303 |
|
304 |
with gr.Column():
|
|
|
334 |
iou_threshold,
|
335 |
conf_threshold,
|
336 |
mor_check,
|
337 |
+
contour_check_e,
|
338 |
retina_check,
|
339 |
],
|
340 |
outputs=segm_img_e)
|
|
|
350 |
iou_threshold,
|
351 |
conf_threshold,
|
352 |
mor_check,
|
353 |
+
contour_check_t,
|
354 |
retina_check,
|
355 |
text_box,
|
356 |
],
|
utils/tools.py
CHANGED
@@ -400,16 +400,17 @@ def point_prompt(masks, points, point_label, target_height, target_width): # nu
|
|
400 |
for point in points
|
401 |
]
|
402 |
onemask = np.zeros((h, w))
|
|
|
403 |
for i, annotation in enumerate(masks):
|
404 |
if type(annotation) == dict:
|
405 |
-
mask = annotation[
|
406 |
else:
|
407 |
mask = annotation
|
408 |
for i, point in enumerate(points):
|
409 |
if mask[point[1], point[0]] == 1 and point_label[i] == 1:
|
410 |
-
onemask
|
411 |
if mask[point[1], point[0]] == 1 and point_label[i] == 0:
|
412 |
-
onemask
|
413 |
onemask = onemask >= 1
|
414 |
return onemask, 0
|
415 |
|
|
|
400 |
for point in points
|
401 |
]
|
402 |
onemask = np.zeros((h, w))
|
403 |
+
masks = sorted(masks, key=lambda x: x['area'], reverse=True)
|
404 |
for i, annotation in enumerate(masks):
|
405 |
if type(annotation) == dict:
|
406 |
+
mask = annotation['segmentation']
|
407 |
else:
|
408 |
mask = annotation
|
409 |
for i, point in enumerate(points):
|
410 |
if mask[point[1], point[0]] == 1 and point_label[i] == 1:
|
411 |
+
onemask[mask] = 1
|
412 |
if mask[point[1], point[0]] == 1 and point_label[i] == 0:
|
413 |
+
onemask[mask] = 0
|
414 |
onemask = onemask >= 1
|
415 |
return onemask, 0
|
416 |
|