# SAM-arena / app.py
# Author: Jose Benitez
# Commit 2e64de9 "Streamlit app (#4)" (unverified)
# Source-viewer residue: raw | history | blame | 4.34 kB
import streamlit as st
from streamlit_drawable_canvas import st_canvas
from PIL import Image
import pandas as pd
import numpy as np
import torch
from utils.SAM import SAM_points_inference, FastSAM_points_inference, EfficientSAM_points_inference, EdgeSAM_points_inference
from utils.draw import draw_SAM_mask_point, draw_FastSAM_point, draw_EdgeSAM_point
from utils.tools import pil_to_bytes
# Fill colour the click canvas draws every point with.
_CLICK_FILL = "rgba(255, 255, 0, 0.8)"
# Point fills that count as positive (foreground, label 1) prompts. Green keeps
# the existing convention. The canvas fill is included because it is the only
# colour this canvas draws; without it every click would be labelled 0.
_POSITIVE_FILLS = frozenset({"rgba(0, 255, 0, 0.8)", _CLICK_FILL})


def _canvas_points(df, scale, positive_fills=_POSITIVE_FILLS):
    """Turn canvas point objects into image-space prompts and labels.

    Args:
        df: ``pandas.json_normalize`` of the canvas ``objects`` list; the
            ``left``, ``top`` and ``fill`` columns are read.
        scale: Ratio of canvas (display) pixels to original image pixels.
        positive_fills: Fill strings whose points get label 1; any other
            fill gets label 0.

    Returns:
        ``(points, labels)``. ``points`` is a list of ``[x, y]`` ints in
        original-image pixels; ``labels`` is the matching list of 0/1 ints.
    """
    input_points = []
    input_labels = []
    for left, top, fill in zip(df["left"], df["top"], df["fill"]):
        # NOTE(review): the +5 shift is applied to x only; it may compensate
        # for how the canvas reports a point's position -- confirm.
        input_points.append([int((left + 5) / scale), int(top / scale)])
        input_labels.append(1 if fill in positive_fills else 0)
    return input_points, input_labels


def click(container_width, height, scale, radius_width, show_mask, im):
    """Render the point-prompt canvas and run four SAM variants on the clicks.

    Args:
        container_width: Canvas width in display pixels.
        height: Original image height in pixels; the canvas is
            ``height * scale`` tall.
        scale: ``container_width / image_width`` (display px per image px).
        radius_width: Passed to the canvas as ``point_display_radius``.
        show_mask: When False, the canvas background is reset from ``im``
            and no inference runs.
        im: The image as a NumPy array (``main`` passes ``np.array`` of an
            RGB PIL image).
    """
    # Drop session keys this mode does not use (presumably left by another mode).
    for stale_key in ('color_change_point_box', 'input_masks_color_box'):
        if stale_key in st.session_state:
            st.session_state.pop(stale_key)

    canvas_result = st_canvas(
        fill_color=_CLICK_FILL,
        background_image=st.session_state['im'],
        drawing_mode='point',
        width=container_width,
        height=height * scale,
        point_display_radius=radius_width,
        stroke_width=2,
        update_streamlit=True,
        key="click",
    )

    if not show_mask:
        # Use the image itself as the canvas background; rerun only when the
        # stored background actually changed.
        background = Image.fromarray(im).convert("RGB")
        changed = background != st.session_state['im']
        st.session_state['im'] = background
        if changed:
            st.rerun()
        return

    if canvas_result.json_data is None:
        # No canvas data on this run, so there is nothing to segment.
        return

    df = pd.json_normalize(canvas_result.json_data["objects"])
    if df.empty:
        # All points were removed: reset the whole session and rerun. clear()
        # also drops 'canvas_result', so the zero point count is recorded afresh.
        st.session_state.clear()
        st.session_state['canvas_result'] = 0
        st.rerun()
        return  # guard in case st.rerun() returns instead of raising

    input_points, input_labels = _canvas_points(df, scale)
    # The SAM / EfficientSAM draw helper receives the first click's coordinates.
    first_x, first_y = input_points[0]

    col1, col2 = st.columns(2)
    with col1:
        # SAM inference (points wrapped in one more list, presumably a batch axis)
        SAM_masks = SAM_points_inference(im, [input_points])
        st.image(draw_SAM_mask_point(im, SAM_masks, first_x, first_y))
        st.success('SAM Inference completed!', icon="✅")
        # EfficientSAM inference
        EfficientSAM_masks = EfficientSAM_points_inference(im, input_points)
        st.image(draw_SAM_mask_point(im, EfficientSAM_masks, first_x, first_y))
        st.success('EfficientSAM Inference completed!', icon="✅")
    with col2:
        # FastSAM inference
        FastSAM_masks = FastSAM_points_inference(im, input_points, input_labels)
        st.image(draw_FastSAM_point(FastSAM_masks))
        st.success('FastSAM Inference completed!', icon="✅")
        # EdgeSAM inference: one label per point, like the FastSAM call, rather
        # than a fixed [1] that only lines up with a single click.
        # NOTE(review): assumes the wrapper expects len(labels) == len(points).
        EdgeSAM_masks = EdgeSAM_points_inference(im, input_points, input_labels)
        st.image(draw_EdgeSAM_point(im, EdgeSAM_masks))
        st.success('EdgeSAM Inference completed!', icon="✅")
def main():
    """Build the sidebar controls and dispatch to the selected segmentation mode.

    With no upload, all session state is cleared. Only 'Click' mode is
    implemented; the other modes show a notice instead of a blank page.
    """
    torch.cuda.empty_cache()
    with st.sidebar:
        uploaded = st.file_uploader(label='Upload image', type=['png', 'jpg', 'tif'])
        option = st.selectbox(
            'Segmentation mode',
            ('Click', 'Box', 'Everything'))
        show_mask = st.checkbox('Show mask', value=True)
        radius_width = st.slider('Radius/Width for Click/Box', 0, 20, 5, 1)

    if uploaded is None:
        st.session_state.clear()
        return

    image = Image.open(uploaded).convert("RGB")
    # Track which upload the stored canvas background came from, so a newly
    # uploaded file replaces it instead of leaving the previous image on the
    # canvas. Identity is approximated by (name, size).
    upload_id = (uploaded.name, uploaded.size)
    if 'im' not in st.session_state or st.session_state.get('im_upload') != upload_id:
        st.session_state['im'] = image
        st.session_state['im_upload'] = upload_id

    width, height = image.size
    container_width = 700  # fixed canvas display width, in pixels
    scale = container_width / width

    if option == 'Click':
        click(container_width,
              height,
              scale,
              radius_width,
              show_mask,
              np.array(image))
    else:
        st.info(f"{option} mode is not available yet.")


if __name__ == '__main__':
    main()