Spaces:
Sleeping
Sleeping
import streamlit as st | |
from streamlit_drawable_canvas import st_canvas | |
from PIL import Image | |
import pandas as pd | |
import numpy as np | |
import torch | |
from utils.SAM import SAM_points_inference, FastSAM_points_inference, EfficientSAM_points_inference, EdgeSAM_points_inference | |
from utils.draw import draw_SAM_mask_point, draw_FastSAM_point, draw_EdgeSAM_point | |
from utils.tools import pil_to_bytes | |
# Fill colour the click canvas draws every point with.
_CLICK_POINT_FILL = "rgba(255, 255, 0, 0.8)"
# Point fills treated as foreground (label 1) prompts. The canvas in ``click``
# only ever draws with _CLICK_POINT_FILL, so that colour must count as
# foreground; previously only the green fill did, which labelled every click
# as background (0) for FastSAM while SAM/EdgeSAM treated it as foreground.
# The green fill is kept so any point drawn with it is still foreground.
_FOREGROUND_FILLS = frozenset({_CLICK_POINT_FILL, "rgba(0, 255, 0, 0.8)"})


def _canvas_prompts(df, scale):
    """Convert canvas point objects into image-space prompt points and labels.

    Args:
        df: ``pd.json_normalize`` of the canvas ``objects`` list; must contain
            ``left``, ``top`` and ``fill`` columns.
        scale: Canvas-pixels-per-image-pixel factor the canvas was sized with.

    Returns:
        ``(points, labels)``: ``points`` is a list of ``[x, y]`` integer image
        coordinates, ``labels`` holds 1 (foreground) or 0 for each point.
    """
    points = [
        # The +5 x-offset is carried over unchanged from the original code;
        # presumably it compensates for where the point marker is anchored on
        # the canvas — TODO confirm against the configured point radius.
        [int((left + 5) / scale), int(top / scale)]
        for left, top in zip(df["left"], df["top"])
    ]
    labels = [1 if fill in _FOREGROUND_FILLS else 0 for fill in df["fill"]]
    return points, labels


def click(container_width, height, scale, radius_width, show_mask, im):
    """Render the point-prompt canvas and show SAM-family results for the clicks.

    Args:
        container_width: Canvas width in pixels.
        height: Image height in pixels; the canvas is ``height * scale`` tall.
        scale: ``container_width / image_width``; maps image to canvas pixels.
        radius_width: Display radius of the clicked points on the canvas.
        show_mask: When False, the canvas background is reset to the raw
            uploaded image and no inference is run.
        im: The uploaded image as a numpy array (as built by ``main``).
    """
    # Drop state left over from other modes (keys presumably written by the
    # Box mode, which is not shown here).
    for key in ('color_change_point_box', 'input_masks_color_box'):
        if key in st.session_state:
            del st.session_state[key]

    canvas_result = st_canvas(
        fill_color=_CLICK_POINT_FILL,
        background_image=st.session_state['im'],
        drawing_mode='point',
        width=container_width,
        height=height * scale,
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="click",
    )

    if not show_mask:
        # Put the raw image back on the canvas; rerun only when the background
        # actually changed so the canvas re-renders with it.
        raw = Image.fromarray(im).convert("RGB")
        changed = raw != st.session_state['im']
        st.session_state['im'] = raw
        if changed:
            st.rerun()
        return

    if canvas_result.json_data is None:
        return

    df = pd.json_normalize(canvas_result.json_data["objects"])
    if df.empty:
        # All points were removed: reset cached state and redraw. clear() also
        # drops 'canvas_result', so the counter is re-seeded and the rerun
        # always fires; st.rerun() halts this script run.
        st.session_state.clear()
        st.session_state['canvas_result'] = 0
        st.rerun()

    input_points, input_labels = _canvas_prompts(df, scale)
    # Only the first click's coordinates are passed to the SAM drawing helper.
    first_x, first_y = input_points[0]

    col1, col2 = st.columns(2)
    with col1:
        # SAM inference
        SAM_masks = SAM_points_inference(im, [input_points])
        st.image(draw_SAM_mask_point(im, SAM_masks, first_x, first_y))
        st.success('SAM Inference completed!', icon="✅")
        # EfficientSAM inference
        EfficientSAM_masks = EfficientSAM_points_inference(im, input_points)
        st.image(draw_SAM_mask_point(im, EfficientSAM_masks, first_x, first_y))
        st.success('EfficientSAM Inference completed!', icon="✅")
    with col2:
        # FastSAM inference
        FastSAM_masks = FastSAM_points_inference(im, input_points, input_labels)
        st.image(draw_FastSAM_point(FastSAM_masks))
        st.success('FastSAM Inference completed!', icon="✅")
        # EdgeSAM inference. A single foreground label is passed regardless of
        # the number of clicks, as before. NOTE(review): confirm
        # EdgeSAM_points_inference does not need one label per point.
        EdgeSAM_masks = EdgeSAM_points_inference(im, input_points, [1])
        st.image(draw_EdgeSAM_point(im, EdgeSAM_masks))
        st.success('EdgeSAM Inference completed!', icon="✅")
def main():
    """Streamlit entry point: sidebar controls, image upload and mode dispatch.

    Only the 'Click' mode is dispatched here; 'Box' and 'Everything' are
    offered in the selectbox but have no branch in this function. With no
    uploaded image, all session state is cleared.
    """
    print('init')
    # Release unoccupied cached CUDA memory held by PyTorch's caching allocator.
    torch.cuda.empty_cache()

    with st.sidebar:
        uploaded = st.file_uploader(
            label='Upload image',
            type=['png', 'jpg', 'jpeg', 'tif', 'tiff'],
        )
        option = st.selectbox(
            'Segmentation mode',
            ('Click', 'Box', 'Everything'))
        show_mask = st.checkbox('Show mask', value=True)
        radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    if not uploaded:
        st.session_state.clear()
        return

    image = Image.open(uploaded).convert("RGB")

    # 'im' is the canvas background. Previously it was only seeded when absent,
    # so uploading a different file kept showing the old image while inference
    # ran on the new one. Re-seed it whenever the upload itself changes.
    upload_id = (uploaded.name, uploaded.size)
    if ('im' not in st.session_state
            or st.session_state.get('uploaded_image_id') != upload_id):
        st.session_state['im'] = image
        st.session_state['uploaded_image_id'] = upload_id

    width, height = image.size
    im = np.array(image)
    container_width = 700
    # Factor mapping image pixels to canvas pixels (canvas is 700 px wide).
    scale = container_width / width

    if option == 'Click':
        click(container_width,
              height,
              scale,
              radius_width,
              show_mask,
              im)
if __name__ == "__main__":
    # Launch the app only when this file is executed as the main script.
    main()