import time
import sys
import streamlit as st
import string
import os
from io import StringIO
import pdb
import json
import torch
import requests
import socket
from streamlit_image_select import image_select
# Use-case descriptions shown on the page (keys are the app_mode values).
use_case = {
    "1": "Image background removal - (upload any picture and remove background)",
    "2": "Masking foreground for downstream inpainting task",
}

# Select-box label -> mask mode sent to the server as the "mask" form field.
# Insertion order is the order the options appear in the select box.
mask_types = {
    "rgba - makes background white": "rgba",
    "green - makes the background green": "green",
    "blur - blurs background": "blur",
    "map - makes the foreground white and rest black ": "map",
}

APP_NAME = "hf/salient_object_detection"  # identifies this app to the stats service
INFO_URL = "https://www.taskswithcode.com/stats/"  # stats/view-count endpoint
TMP_DIR = "tmp_dir"  # directory for temporary copies of uploaded files
TMP_SEED = 1  # counter mixed into temp file names; run_test increments it
def get_views(action, timeout=5):
    """Post a usage event to INFO_URL and return the session's view count.

    The first call in a Streamlit session posts the event, stores the returned
    "count" (0 if the request or response parsing fails) in
    st.session_state["view_count"], and returns it. Later calls return the
    stored count; when ``action`` is not "init" they also re-post the event,
    ignoring network errors. Telemetry is best-effort and must not break the page.

    Args:
        action: Event label sent to the stats service, e.g. "init" or "submit".
        timeout: Seconds to wait for the stats service on each request.

    Returns:
        The count formatted with thousands separators, e.g. "1,234".
    """
    hostname = socket.gethostname()
    try:
        ip_address = socket.gethostbyname(hostname)
    except OSError:
        # An unresolvable hostname should not take the whole page down.
        ip_address = ""
    app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}

    if "view_count" not in st.session_state:
        try:
            res = requests.post(INFO_URL, json=app_info, timeout=timeout).json()
            print(res)
            data = res["count"]
        except (requests.RequestException, ValueError, KeyError, TypeError):
            # Network failure, non-JSON body, or unexpected payload shape.
            data = 0
        st.session_state["view_count"] = data
        ret_val = data
    else:
        ret_val = st.session_state["view_count"]
        if action != "init":
            # Fire-and-forget: the response body is not used.
            try:
                requests.post(INFO_URL, json=app_info, timeout=timeout)
            except requests.RequestException:
                pass
    return "{:,}".format(ret_val)
def construct_model_info_for_display(model_names):
    """Collect model names and build the model-info markdown shown under the form.

    Args:
        model_names: Sequence of model dicts (loaded from the models JSON file).
            Every entry must have a "name" key and a "mark" key; entries whose
            "mark" is the string "True" are listed in the markdown.

    Returns:
        A tuple ``(names, markdown)``: the "name" of every entry in input order,
        and the markdown string.

    Raises:
        KeyError: if an entry lacks "name" or "mark".
    """
    names = []
    pieces = ["\nModel evaluated\n", "\n"]
    for entry in model_names:
        names.append(entry["name"])
        if entry["mark"] != "True":
            continue
        # NOTE(review): the per-model detail and "Note" fragments are empty
        # f-strings in this version, so each listed model contributes only a
        # line break. Restore the per-model markup here if it is needed.
        pieces.append("\n")
    pieces.append(
        "Note:\n"
        "• Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached\n"
    )
    return names, "".join(pieces)
def init_page():
    """Configure the Streamlit page and draw the logo banner."""
    st.set_page_config(
        page_title='TWC - State-of-the-art model salient object detection (visually dominant objects in an image)',
        page_icon="logo.jpg",
        layout='centered',
        initial_sidebar_state='auto',
        menu_items={
            'About': 'This app was created by taskswithcode. http://taskswithcode.com',
        },
    )
    # The logo occupies an 85% column; the 15% column is left empty as padding.
    logo_col, _padding = st.columns([85, 15])
    with logo_col:
        st.image("long_form_logo_with_icon.png")
def run_test(config, input_file_name, display_area, uploaded_file, mask_type):
    """Send an image to the inference server and return its response.

    When ``uploaded_file`` is None the file at ``input_file_name`` is sent.
    Otherwise the uploaded bytes are written to a uniquely named file under
    TMP_DIR and that file is sent. The temporary file is always removed
    afterwards, even if the request fails.

    Args:
        config: Parsed config dict; ``config["SERVER_ADDRESS"]`` is the URL posted to.
        input_file_name: Path of the example image used when nothing was uploaded.
        display_area: Streamlit placeholder that shows a progress message.
        uploaded_file: Uploaded file object or None (assumed to expose
            ``.read()`` and ``.name`` like Streamlit's UploadedFile — confirm).
        mask_type: Mask mode sent as the "mask" form field.

    Returns:
        ``{"response": bytes, "size": str}`` on HTTP 200, otherwise
        ``{"error": str}``. On an exception the error is shown and
        ``st.stop()`` is called.
    """
    global TMP_SEED
    display_area.text("Processing request...")
    try:
        if uploaded_file is None:
            with open(input_file_name, "rb") as file_data:
                r = requests.post(config["SERVER_ADDRESS"], data={"mask": mask_type}, files={"test": file_data})
        else:
            payload = uploaded_file.read()
            os.makedirs(TMP_DIR, exist_ok=True)
            # The upload name is client-supplied; keep only its last path component.
            base_name = os.path.basename(uploaded_file.name)
            file_name = f"{TMP_DIR}/{TMP_SEED}_{str(time.time()).replace('.','_')}_{base_name}"
            TMP_SEED += 1
            try:
                with open(file_name, "wb") as fp:
                    fp.write(payload)
                with open(file_name, "rb") as file_data:
                    r = requests.post(config["SERVER_ADDRESS"], data={"mask": mask_type}, files={"test": file_data})
            finally:
                # The file is already closed here, so removal is safe on all platforms.
                try:
                    os.remove(file_name)
                except FileNotFoundError:
                    pass
        print("Servers response:", r.status_code, len(r.content))
        if r.status_code == 200:
            size = "{:,}".format(len(r.content))
            return {"response": r.content, "size": size}
        return {"error": f"API request failed {r.status_code}"}
    except Exception as e:
        st.error("Some error occurred during prediction" + str(e))
        st.stop()
        # Fallback in case st.stop() returns instead of halting the script.
        return {"error": f"Exception in performing salient object detection: {str(e)}"}
def display_results(results, response_info, mask):
    """Render the summary line and output image, and stash the image for download.

    Args:
        results: Result dict from run_test; must contain "response" (the bytes
            returned by the server).
        response_info: One-line summary shown above the image.
        mask: Mask mode that was used; shown in the image caption.
    """
    # The summary is followed by a blank line before the image.
    st.markdown(f"{response_info}\n\n", unsafe_allow_html=True)
    st.image(results["response"], caption=f'Output of salient object detection with mask: {mask}')
    st.session_state["download_ready"] = results["response"]
    get_views("submit")
def init_session():
    """Render the page header and (re)set the session-state defaults.

    app_main calls this unconditionally, so these keys are reset every time the
    script runs.
    """
    print("Init session")
    init_page()
    # NOTE(review): this previously assigned "insprynet" and then immediately
    # overwrote it with "ss_test". "ss_test" is the value that actually takes
    # effect (it prefixes the download file name), so it is kept. Confirm which
    # model name is intended.
    st.session_state["model_name"] = "ss_test"
    st.session_state["download_ready"] = None
    st.session_state["file_name"] = "default"
    st.session_state["mask_type"] = "rgba"
def app_main(app_mode, example_files, model_name_files, config_file):
    """Render the salient-object-detection page and handle a form submission.

    Args:
        app_mode: Key into ``use_case``; an unknown key raises KeyError.
        example_files: Path to the examples JSON file (loaded; the example
            images offered below are currently hard-coded).
        model_name_files: Path to the models JSON, passed to
            construct_model_info_for_display.
        config_file: Path to the JSON config handed to run_test.
    """
    init_session()
    with open(example_files) as fp:
        example_file_names = json.load(fp)  # NOTE(review): not used below
    with open(model_name_files) as fp:
        model_names = json.load(fp)
    with open(config_file) as fp:
        config = json.load(fp)
    if app_mode not in use_case:
        raise KeyError(app_mode)

    st.markdown("State-of-the-art model for salient object detection\n", unsafe_allow_html=True)
    st.markdown(
        f"Use cases for salient object detection\n• {use_case['1']}\n• {use_case['2']}\n",
        unsafe_allow_html=True,
    )
    st.markdown(f"views: {get_views('init')}\n", unsafe_allow_html=True)
    try:
        with st.form('twc_form'):
            step1_line = "Upload an image or choose an example image below"
            uploaded_file = st.file_uploader(step1_line, type=["png", "jpg", "jpeg"])
            selected_file_name = image_select(
                "Select image",
                ["twc_samples/sample1.jpg", "twc_samples/sample2.jpg", "twc_samples/sample3.jpg", "twc_samples/sample4.jpg"],
            )
            st.write("")
            mask_label = st.selectbox(
                label='Select type of masking',
                options=list(mask_types), index=0, key="twc_mask_types",
            )
            mask_type = mask_types[mask_label]
            st.write("")
            submit_button = st.form_submit_button('Run')
        _, markdown_str = construct_model_info_for_display(model_names)
        st.empty()  # reserved status slot; nothing writes to it currently
        display_area = st.empty()
        if submit_button:
            start = time.time()
            if uploaded_file is not None:
                st.session_state["file_name"] = uploaded_file.name
            else:
                st.session_state["file_name"] = selected_file_name
            st.session_state["mask_type"] = mask_type
            display_area.empty()
            results = run_test(config, st.session_state["file_name"], display_area, uploaded_file, mask_type)
            with display_area.container():
                if "error" in results:
                    st.error(results["error"])
                else:
                    device = 'GPU' if torch.cuda.is_available() else 'CPU'
                    response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for image size: {results['size']} bytes"
                    display_results(results, response_info, mask_type)
                    download_data = st.session_state["download_ready"]
                    st.download_button(
                        label="Download results as png",
                        data=download_data if download_data is not None else "",
                        disabled=download_data is None,
                        file_name=(st.session_state["model_name"] + "_" + st.session_state["mask_type"] + "_" + '_'.join(st.session_state["file_name"].split(".")[:-1]) + ".png").replace("/", "_"),
                        mime='image/png',
                        key="download",
                    )
    except Exception as e:
        st.error("Some error occurred during loading" + str(e))
        st.stop()
    st.markdown(markdown_str, unsafe_allow_html=True)
# Script entry point: launch the page in use-case "1" with the bundled JSON files.
if __name__ == "__main__":
    app_main(
        app_mode="1",
        example_files="sod_app_examples.json",
        model_name_files="sod_app_models.json",
        config_file="config.json",
    )