|
import time |
|
import sys |
|
import streamlit as st |
|
import string |
|
import os |
|
from io import StringIO |
|
import pdb |
|
import json |
|
import torch |
|
import requests |
|
import socket |
|
from streamlit_image_select import image_select |
|
|
|
|
|
|
|
|
|
|
|
# Descriptions of the supported use cases, keyed by the app-mode string that
# app_main receives.
use_case = {
    "1": "Image background removal - (upload any picture and remove background)",
    "2": "Masking foreground for downstream inpainting task",
}

# Maps the label shown in the masking selectbox to the short mask identifier
# that is sent to the server as the "mask" form field.
mask_types = {
    "rgba - makes background white": "rgba",
    "green - makes the background green": "green",
    "blur - blurs background": "blur",
    "map - makes the foreground white and rest black ": "map",
}
|
|
|
|
|
|
|
|
|
# Identifier reported as "name" in the usage-stats payload.
APP_NAME: str = "hf/salient_object_detection"

# Endpoint that receives the usage-stats POST requests.
INFO_URL: str = "https://www.taskswithcode.com/stats/"

# Directory in which uploaded images are written temporarily before being
# forwarded to the inference server.
TMP_DIR: str = "tmp_dir"

# Counter used as a prefix for temporary file names; incremented per upload.
TMP_SEED: int = 1
|
|
|
|
|
|
|
|
|
|
|
def _post_app_event(action, timeout=5):
    """Send one usage event to INFO_URL and return the decoded JSON reply.

    The payload carries the app name, the action label, the local hostname
    and its resolved IP address.  Reporting is best effort: ``None`` is
    returned when the request fails, times out, or the reply is not JSON.
    """
    hostname = socket.gethostname()
    try:
        ip_address = socket.gethostbyname(hostname)
    except OSError:
        # The local hostname may not be resolvable; that must not prevent
        # the page from rendering.
        ip_address = "unknown"
    app_info = {'name': APP_NAME, "action": action, "host": hostname, "ip": ip_address}
    try:
        return requests.post(INFO_URL, json=app_info, timeout=timeout).json()
    except (requests.RequestException, ValueError):
        return None


def get_views(action):
    """Return the view count as a string with thousands separators.

    On the first call of a session the stats endpoint is contacted with
    *action* and the returned ``count`` is cached in
    ``st.session_state["view_count"]`` (0 if it cannot be obtained).  On later
    calls the cached value is returned; any action other than ``"init"`` is
    additionally reported to the endpoint, ignoring the reply.

    Args:
        action: Event label sent to the stats endpoint (e.g. "init", "submit").

    Returns:
        The cached or freshly fetched count formatted like ``"1,234"``.
    """
    if "view_count" not in st.session_state:
        res = _post_app_event(action)
        if res is not None:
            print(res)
        try:
            data = res["count"]
        except (KeyError, TypeError, IndexError):
            data = 0
        if not isinstance(data, (int, float)):
            # A non-numeric count cannot be rendered with the "," format spec.
            data = 0
        st.session_state["view_count"] = data
        ret_val = data
    else:
        ret_val = st.session_state["view_count"]
        if action != "init":
            # Fire-and-forget: a failure to report must not break the page.
            _post_app_event(action)
    return "{:,}".format(ret_val)
|
|
|
|
|
|
|
|
|
def construct_model_info_for_display(model_names):
    """Build the model name list and the HTML footer describing the models.

    Args:
        model_names: Iterable of dicts. Every entry must have "name" and
            "mark"; entries whose "mark" equals the string "True" must also
            provide "paper_url", "orig_author_url", "orig_author" and
            "sota_info" (with "sota_link" and "task"), and may provide
            "Note" together with "alt_url".

    Returns:
        A tuple ``(names, html)``: every entry's name in input order, and an
        HTML string listing the marked models followed by fixed notes.
    """
    names = []
    fragments = [
        "<div style=\"font-size:16px; color: #2f2f2f; text-align: left\"><br/><b>Model evaluated </b><br/></div>",
        "<div style=\"font-size:2px; color: #2f2f2f; text-align: left\"><br/></div>",
    ]

    for entry in model_names:
        names.append(entry["name"])
        if entry["mark"] != "True":
            # Unmarked models are listed by name only, not rendered.
            continue
        fragments.append(
            f"<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"> • Model: <a href='{entry['paper_url']}' target='_blank'>{entry['name']}</a><br/> Code released by: <a href='{entry['orig_author_url']}' target='_blank'>{entry['orig_author']}</a><br/> Model info: <a href='{entry['sota_info']['sota_link']}' target='_blank'>{entry['sota_info']['task']}</a></div>"
        )
        if "Note" in entry:
            fragments.append(
                f"<div style=\"font-size:16px; color: #a91212; text-align: left\"> {entry['Note']}<a href='{entry['alt_url']}' target='_blank'>link</a></div>"
            )
        # Spacer after each rendered model.
        fragments.append("<div style=\"font-size:16px; color: #5f5f5f; text-align: left\"><br/></div>")

    fragments.append(
        "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><b>Note:</b><br/>• Uploaded files are loaded into non-persistent memory for the duration of the computation. They are not cached</div>"
    )
    fragments.append(
        "<div style=\"font-size:12px; color: #9f9f9f; text-align: left\"><br/><a href='https://github.com/taskswithcode/salient_object_detection_app.git' target='_blank'>Github code</a> for this app</div>"
    )
    return names, "".join(fragments)
|
|
|
|
|
def init_page():
    """Configure the Streamlit page and render the logo at the top."""
    about_text = 'This app was created by taskswithcode. http://taskswithcode.com'
    st.set_page_config(
        page_title='TWC - State-of-the-art model salient object detection (visually dominant objects in an image)',
        page_icon="logo.jpg",
        layout='centered',
        initial_sidebar_state='auto',
        menu_items={'About': about_text},
    )

    # Two columns in an 85:15 ratio; only the wide one receives content.
    logo_column, _unused_column = st.columns([85, 15])
    with logo_column:
        st.image("long_form_logo_with_icon.png")
|
|
|
|
|
def run_test(config,input_file_name,display_area,uploaded_file,mask_type):
    """Send an image to the inference server and return its reply.

    When *uploaded_file* is ``None`` the file at *input_file_name* is sent.
    Otherwise the uploaded bytes are written to a uniquely named temporary
    file under ``TMP_DIR``, that file is sent, and it is removed afterwards
    even if the request fails.

    Args:
        config: Mapping that provides the server URL under "SERVER_ADDRESS".
        input_file_name: Path of the image to send when nothing was uploaded.
        display_area: Object with a ``text`` method used for a progress note.
        uploaded_file: Uploaded file object (``read()`` and ``name``) or None.
        mask_type: Value sent to the server as the "mask" form field.

    Returns:
        ``{"response": <bytes>, "size": <formatted byte count>}`` on HTTP 200,
        otherwise ``{"error": <message>}``.
    """
    global TMP_SEED
    display_area.text("Processing request...")
    try:
        if uploaded_file is None:
            # Context manager closes the handle once the request completes.
            with open(input_file_name, "rb") as image_fp:
                r = requests.post(config["SERVER_ADDRESS"], data={"mask": mask_type}, files={"test": image_fp})
        else:
            os.makedirs(TMP_DIR, exist_ok=True)
            # Keep only the final path component so a crafted upload name
            # cannot place the temporary file outside TMP_DIR.
            safe_name = os.path.basename(uploaded_file.name)
            file_name = f"{TMP_DIR}/{TMP_SEED}_{str(time.time()).replace('.','_')}_{safe_name}"
            TMP_SEED += 1
            try:
                with open(file_name, "wb") as fp:
                    fp.write(uploaded_file.read())
                with open(file_name, "rb") as image_fp:
                    r = requests.post(config["SERVER_ADDRESS"], data={"mask": mask_type}, files={"test": image_fp})
            finally:
                # Clean up regardless of whether the upload or request failed.
                if os.path.exists(file_name):
                    os.remove(file_name)
        print("Servers response:",r.status_code,len(r.content))
        if r.status_code == 200:
            size = "{:,}".format(len(r.content))
            return {"response": r.content, "size": size}
        return {"error": f"API request failed {r.status_code}"}
    except Exception as e:
        st.error("Some error occurred during prediction" + str(e))
        st.stop()
        # Only reached if st.stop() returns instead of halting the script.
        return {"error": f"Exception in performing salient object detection: {str(e)}"}
|
|
|
|
|
|
|
|
|
def display_results(results,response_info,mask):
    """Render the timing summary and the returned image, then log the submit.

    Args:
        results: Mapping whose "response" entry is passed to ``st.image``
            (presumably the image bytes returned by the server).
        response_info: Summary text shown above the image.
        mask: Mask identifier shown in the image caption.

    Side effects:
        Stores ``results["response"]`` in ``st.session_state["download_ready"]``
        and calls ``get_views("submit")``.
    """
    # The trailing newline is part of the markup emitted for the summary.
    summary_html = (
        f"<div style=\"font-size:14px; color: #2f2f2f; text-align: left\">{response_info}<br/><br/></div>"
        "\n"
    )
    st.markdown(summary_html, unsafe_allow_html=True)

    caption_text = f'Output of salient object detection with mask: {mask}'
    st.image(results["response"], caption=caption_text)

    st.session_state["download_ready"] = results["response"]
    get_views("submit")
|
|
|
|
|
def init_session():
    """Set up the page and reset the per-run session defaults.

    Calls ``init_page()`` and then unconditionally overwrites the
    "download_ready", "model_name", "file_name" and "mask_type" entries of
    ``st.session_state`` with their default values.
    """
    print("Init session")
    init_page()
    st.session_state["download_ready"] = None
    # NOTE(review): this key used to be assigned "insprynet" and then
    # immediately overwritten with "ss_test"; only the effective value is
    # kept. Confirm which name is actually intended.
    st.session_state["model_name"] = "ss_test"
    st.session_state["file_name"] = "default"
    st.session_state["mask_type"] = "rgba"
|
|
|
def _download_file_name(model_name, mask_type, source_name):
    """Return the suggested file name for the downloadable PNG.

    The last dot-separated component of *source_name* is dropped and the
    remaining components are joined with "_"; a name without any dot is used
    whole.  Any "/" in the final name is replaced by "_".
    """
    stem_parts = source_name.split(".")[:-1] or [source_name]
    stem = "_".join(stem_parts)
    return (model_name + "_" + mask_type + "_" + stem + ".png").replace("/", "_")


def app_main(app_mode,example_files,model_name_files,config_file):
    """Render the app: header, input form, results and download button.

    Args:
        app_mode: Key into ``use_case``; an unknown key raises ``KeyError``.
        example_files: Path to a JSON file of example file names.
        model_name_files: Path to a JSON file describing the models shown in
            the footer.
        config_file: Path to a JSON config file passed on to ``run_test``.

    Any exception raised while building the form or showing results is
    reported with ``st.error`` and the script run is stopped.
    """
    init_session()
    # NOTE(review): the example list is parsed but not used below (the sample
    # images are hard-coded); it is still loaded so that a missing or invalid
    # file fails exactly as before.
    with open(example_files) as fp:
        example_file_names = json.load(fp)
    with open(model_name_files) as fp:
        model_names = json.load(fp)
    with open(config_file) as fp:
        config = json.load(fp)

    # Fail fast on an unknown mode, as the former lookup did.
    if app_mode not in use_case:
        raise KeyError(app_mode)

    st.markdown("<h5 style='text-align: center;'>State-of-the-art model for salient object detection</h5>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #4f4f4f; text-align: left'>Use cases for salient object detection<br/> • {use_case['1']}<br/> • {use_case['2']}</div>", unsafe_allow_html=True)
    st.markdown(f"<div style='color: #9f9f9f; text-align: right'>views: {get_views('init')}</div>", unsafe_allow_html=True)

    try:
        with st.form('twc_form'):
            step1_line = "Upload an image or choose an example image below"
            uploaded_file = st.file_uploader(step1_line, type=["png", "jpg", "jpeg"])

            selected_file_name = image_select(
                "Select image",
                [
                    "twc_samples/sample1.jpg",
                    "twc_samples/sample2.jpg",
                    "twc_samples/sample3.jpg",
                    "twc_samples/sample4.jpg",
                ],
            )

            st.write("")
            mask_label = st.selectbox(
                label='Select type of masking',
                options=list(mask_types),
                index=0,
                key="twc_mask_types",
            )
            mask_type = mask_types[mask_label]
            st.write("")
            submit_button = st.form_submit_button('Run')
            _, markdown_str = construct_model_info_for_display(model_names)

            # Placeholder slot kept so the element layout is unchanged.
            st.empty()
            display_area = st.empty()
            if submit_button:
                start = time.time()
                # An uploaded image takes precedence over the selected sample.
                if uploaded_file is not None:
                    st.session_state["file_name"] = uploaded_file.name
                else:
                    st.session_state["file_name"] = selected_file_name
                st.session_state["mask_type"] = mask_type
                display_area.empty()
                results = run_test(config, st.session_state["file_name"], display_area, uploaded_file, mask_type)
                with display_area.container():
                    if "error" in results:
                        st.error(results["error"])
                    else:
                        # NOTE(review): this reports the device of the machine
                        # running this script; confirm it matches the server
                        # that performs the computation.
                        device = 'GPU' if torch.cuda.is_available() else 'CPU'
                        response_info = f"Computation time on {device}: {time.time() - start:.2f} secs for image size: {results['size']} bytes"
                        display_results(results, response_info, mask_type)

        # The download button is disabled until a result has been stored.
        download_payload = st.session_state["download_ready"]
        st.download_button(
            label="Download results as png",
            data=download_payload if download_payload is not None else "",
            disabled=download_payload is None,
            file_name=_download_file_name(
                st.session_state["model_name"],
                st.session_state["mask_type"],
                st.session_state["file_name"],
            ),
            mime='image/png',
            key="download",
        )

    except Exception as e:
        st.error("Some error occurred during loading" + str(e))
        st.stop()

    st.markdown(markdown_str, unsafe_allow_html=True)
|
|
|
|
|
|
|
# Script entry point: launch the app in background-removal mode ("1") with
# the default example, model and config files.
if __name__ == "__main__":
    app_main(
        app_mode="1",
        example_files="sod_app_examples.json",
        model_name_files="sod_app_models.json",
        config_file="config.json",
    )
|
|
|
|