Spaces:
Sleeping
Sleeping
File size: 4,338 Bytes
2e64de9 aa36c04 2e64de9 aa36c04 2e64de9 aa36c04 2e64de9 aa36c04 2e64de9 aa36c04 2e64de9 aa36c04 2e64de9 aa36c04 2e64de9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import pandas as pd
import numpy as np
import torch
from utils.SAM import SAM_points_inference, FastSAM_points_inference, EfficientSAM_points_inference, EdgeSAM_points_inference
from utils.draw import draw_SAM_mask_point, draw_FastSAM_point, draw_EdgeSAM_point
from utils.tools import pil_to_bytes
# Fill colour the click canvas paints every point with.
_CLICK_POINT_FILL = "rgba(255, 255, 0, 0.8)"
# Point fills treated as foreground (label 1) prompts: the canvas's own point
# fill, plus the green fill the original label check looked for. Previously only
# green counted, so every (yellow) click was labelled 0 / background.
_FOREGROUND_POINT_FILLS = frozenset({_CLICK_POINT_FILL, "rgba(0, 255, 0, 0.8)"})


def _click_prompts(objects_df, scale):
    """Convert canvas point objects into image-space prompts.

    Args:
        objects_df: DataFrame built with ``pd.json_normalize`` from the
            canvas's ``objects`` list; must have ``left``, ``top`` and
            ``fill`` columns.
        scale: canvas pixels per image pixel (canvas width / image width).

    Returns:
        ``(points, labels)`` where ``points`` is a list of ``[x, y]`` integer
        image coordinates and ``labels`` is a parallel list holding 1 for
        foreground-coloured points and 0 for anything else.
    """
    points = []
    labels = []
    for left, top, fill in zip(objects_df["left"], objects_df["top"], objects_df["fill"]):
        # NOTE(review): the +5 x-offset presumably matches the default point
        # radius of 5; y gets no offset and the slider radius is ignored.
        # Confirm against the canvas's point geometry before relying on it.
        points.append([int((left + 5) / scale), int(top / scale)])
        labels.append(1 if fill in _FOREGROUND_POINT_FILLS else 0)
    return points, labels


def _show_point_results(im, input_points, input_labels):
    """Run the four point-prompted models on ``im`` and render them in two columns."""
    col1, col2 = st.columns(2)
    with col1:
        # SAM inference (takes a batch of point lists)
        sam_masks = SAM_points_inference(im, [input_points])
        st.image(draw_SAM_mask_point(im, sam_masks, input_points[0][0], input_points[0][1]))
        st.success('SAM Inference completed!', icon="✅")
        # EfficientSAM inference
        efficient_masks = EfficientSAM_points_inference(im, input_points)
        st.image(draw_SAM_mask_point(im, efficient_masks, input_points[0][0], input_points[0][1]))
        st.success('EfficientSAM Inference completed!', icon="✅")
    with col2:
        # FastSAM inference
        fastsam_masks = FastSAM_points_inference(im, input_points, input_labels)
        st.image(draw_FastSAM_point(fastsam_masks))
        st.success('FastSAM Inference completed!', icon="✅")
        # EdgeSAM inference: pass one label per point (was a hard-coded [1],
        # which did not match the point count once several points were drawn).
        edgesam_masks = EdgeSAM_points_inference(im, input_points, input_labels)
        st.image(draw_EdgeSAM_point(im, edgesam_masks))
        st.success('EdgeSAM Inference completed!', icon="✅")


def click(container_width, height, scale, radius_width, show_mask, im):
    """Render the point-prompt canvas and compare models on the drawn points.

    Args:
        container_width: canvas width in pixels.
        height: image height in pixels; the canvas height is ``height * scale``.
        scale: canvas pixels per image pixel (``container_width / width`` in main).
        radius_width: display radius of the drawn points.
        show_mask: when False, only resync the canvas background with ``im``
            (rerunning if it changed) and skip inference; when True, run the
            models on whatever points are on the canvas.
        im: the uploaded image as a numpy array.
    """
    # Drop state left behind by the other segmentation modes.
    for stale_key in ('color_change_point_box', 'input_masks_color_box'):
        if stale_key in st.session_state:
            st.session_state.pop(stale_key)

    canvas_result = st_canvas(
        fill_color=_CLICK_POINT_FILL,
        background_image=st.session_state['im'],
        drawing_mode='point',
        width=container_width,
        height=height * scale,
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="click",
    )

    if not show_mask:
        # Show the plain image as background; rerun only when it actually changed.
        im = Image.fromarray(im).convert("RGB")
        changed = im != st.session_state['im']
        st.session_state['im'] = im
        if changed:
            st.rerun()
        return

    if canvas_result.json_data is None:
        return

    df = pd.json_normalize(canvas_result.json_data["objects"])
    if df.empty:
        # Canvas emptied: wipe all session state (background image and widget
        # state included), record zero points and rerun. clear() always removes
        # 'canvas_result', so the old "count changed" elif could never fire.
        st.session_state.clear()
        st.session_state['canvas_result'] = 0
        st.rerun()
        return

    input_points, input_labels = _click_prompts(df, scale)
    _show_point_results(im, input_points, input_labels)
def main():
    """Streamlit entry point: sidebar controls, image upload and mode dispatch.

    While no image is uploaded, all session state is cleared. Only the
    'Click' mode is dispatched here; 'Box' and 'Everything' render nothing
    in this file.
    """
    print('init')
    torch.cuda.empty_cache()
    with st.sidebar:
        uploaded = st.file_uploader(label='Upload image', type=['png', 'jpg', 'tif'])
        option = st.selectbox(
            'Segmentation mode',
            ('Click', 'Box', 'Everything'))
        show_mask = st.checkbox('Show mask', value=True)
        radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    if uploaded is None:
        st.session_state.clear()
        return

    image = Image.open(uploaded).convert("RGB")
    # Re-seed the canvas background whenever a different file is uploaded.
    # Seeding only when 'im' was absent kept showing the previous image after
    # the upload was swapped, while inference ran on the new one.
    upload_key = (uploaded.name, uploaded.size)
    if 'im' not in st.session_state or st.session_state.get('im_upload') != upload_key:
        st.session_state['im'] = image
        st.session_state['im_upload'] = upload_key

    width, height = image.size
    container_width = 700
    scale = container_width / width  # canvas pixels per image pixel
    if option == 'Click':
        click(container_width,
              height,
              scale,
              radius_width,
              show_mask,
              np.array(image))
# Script entry point (e.g. `streamlit run` executes this module as __main__).
if __name__ == '__main__':
    main()