File size: 4,338 Bytes
2e64de9
 
aa36c04
 
2e64de9
aa36c04
 
 
2e64de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa36c04
2e64de9
 
aa36c04
2e64de9
 
aa36c04
2e64de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa36c04
2e64de9
 
 
 
 
aa36c04
2e64de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import streamlit as st
from streamlit_drawable_canvas import st_canvas

from PIL import Image
import pandas as pd
import numpy as np
import torch

from utils.SAM import SAM_points_inference, FastSAM_points_inference, EfficientSAM_points_inference, EdgeSAM_points_inference
from utils.draw import draw_SAM_mask_point, draw_FastSAM_point, draw_EdgeSAM_point
from utils.tools import pil_to_bytes

def click(container_width,height,scale,radius_width,show_mask,im):
    for each in ['color_change_point_box','input_masks_color_box']:
        if each in st.session_state:st.session_state.pop(each)
    canvas_result = st_canvas(
            fill_color="rgba(255, 255, 0, 0.8)",
            background_image = st.session_state['im'],
            drawing_mode='point',
            width = container_width,
            height = height * scale,
            point_display_radius = radius_width,
            stroke_width=2,
            update_streamlit=True,
            key="click",)
    if not show_mask:
        im = Image.fromarray(im).convert("RGB")
        rerun = False
        if im != st.session_state['im']:
            rerun = True
        st.session_state['im'] = im
        if rerun:
            st.rerun()
    elif canvas_result.json_data is not None:
        df = pd.json_normalize(canvas_result.json_data["objects"])
        if len(df) == 0:
            st.session_state.clear()
            if 'canvas_result' not in st.session_state:
                st.session_state['canvas_result'] = len(df)
                st.rerun()
            elif len(df) != st.session_state['canvas_result']:
                st.session_state['canvas_result'] = len(df)
                st.rerun()
            return
        
        df["center_x"] = df["left"]
        df["center_y"] = df["top"]
        
        input_points = []
        input_labels = []
        
        for _, row in df.iterrows():
            x, y = row["center_x"] + 5, row["center_y"]
            x = int(x/scale)
            y = int(y/scale)
            input_points.append([x, y])
            if row['fill'] == "rgba(0, 255, 0, 0.8)":
                input_labels.append(1)
            else:
                input_labels.append(0)
        
        col1, col2 = st.columns(2)
        
        with col1:
            # SAM inference
            SAM_masks = SAM_points_inference(im, [input_points])
            st.image(draw_SAM_mask_point(im, SAM_masks, input_points[0][0], input_points[0][1]))
            st.success('SAM Inference completed!', icon="✅")
            
            # EfficientSAM inference
            EfficientSAM_masks = EfficientSAM_points_inference(im, input_points)
            st.image(draw_SAM_mask_point(im, EfficientSAM_masks, input_points[0][0], input_points[0][1]))
            st.success('EfficientSAM Inference completed!', icon="✅")
        
        with col2:
            # FastSAM inference
            FastSAM_masks = FastSAM_points_inference(im, input_points, input_labels)
            st.image(draw_FastSAM_point(FastSAM_masks))
            st.success('FastSAM Inference completed!', icon="✅")
        
            # EdgeSAM inference
            EdgeSAM_masks = EdgeSAM_points_inference(im, input_points, [1])
            st.image(draw_EdgeSAM_point(im, EdgeSAM_masks))
            st.success('EdgeSAM Inference completed!', icon="✅")
        
        
def main():
    """Build the sidebar controls and dispatch to the selected segmentation mode.

    Only the ``'Click'`` mode is rendered by this function; the other options
    currently render nothing beyond the sidebar. Without an uploaded image,
    all session state is cleared.
    """
    print('init')
    torch.cuda.empty_cache()

    with st.sidebar:
        uploaded = st.file_uploader(label='Upload image',type=['png','jpg','tif'])
        option = st.selectbox(
            'Segmentation mode',
            ('Click', 'Box', 'Everything'))

        show_mask = st.checkbox('Show mask',value = True)
        radius_width = st.slider('Radius/Width for Click/Box',0,20,5,1)

    if not uploaded:
        st.session_state.clear()
        return

    image = Image.open(uploaded).convert("RGB")
    # Seed the canvas background when missing OR when a different file was
    # uploaded; seeding only when missing kept showing the previous image after
    # the user swapped files.
    upload_key = getattr(uploaded, 'file_id', None) or (uploaded.name, uploaded.size)
    if ('im' not in st.session_state
            or st.session_state.get('uploaded_image_key') != upload_key):
        st.session_state['im'] = image
        st.session_state['uploaded_image_key'] = upload_key

    width, height = image.size
    pixels = np.array(image)
    container_width = 700  # fixed canvas width in display pixels
    scale = container_width / width
    if option == 'Click':
        click(container_width,
              height,
              scale,
              radius_width,
              show_mask,
              pixels)

if __name__ == "__main__":
    # Run the app only when this file is executed as a script, not on import.
    main()